brainstate 0.0.2.post20241009__py2.py3-none-any.whl → 0.1.0__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +31 -11
- brainstate/_state.py +760 -316
- brainstate/_state_test.py +41 -12
- brainstate/_utils.py +31 -4
- brainstate/augment/__init__.py +40 -0
- brainstate/augment/_autograd.py +608 -0
- brainstate/augment/_autograd_test.py +1193 -0
- brainstate/augment/_eval_shape.py +102 -0
- brainstate/augment/_eval_shape_test.py +40 -0
- brainstate/augment/_mapping.py +525 -0
- brainstate/augment/_mapping_test.py +210 -0
- brainstate/augment/_random.py +99 -0
- brainstate/{transform → compile}/__init__.py +25 -13
- brainstate/compile/_ad_checkpoint.py +204 -0
- brainstate/compile/_ad_checkpoint_test.py +51 -0
- brainstate/compile/_conditions.py +259 -0
- brainstate/compile/_conditions_test.py +221 -0
- brainstate/compile/_error_if.py +94 -0
- brainstate/compile/_error_if_test.py +54 -0
- brainstate/compile/_jit.py +314 -0
- brainstate/compile/_jit_test.py +143 -0
- brainstate/compile/_loop_collect_return.py +516 -0
- brainstate/compile/_loop_collect_return_test.py +59 -0
- brainstate/compile/_loop_no_collection.py +185 -0
- brainstate/compile/_loop_no_collection_test.py +51 -0
- brainstate/compile/_make_jaxpr.py +756 -0
- brainstate/compile/_make_jaxpr_test.py +134 -0
- brainstate/compile/_progress_bar.py +111 -0
- brainstate/compile/_unvmap.py +159 -0
- brainstate/compile/_util.py +147 -0
- brainstate/environ.py +408 -381
- brainstate/environ_test.py +34 -32
- brainstate/{nn/event → event}/__init__.py +6 -6
- brainstate/event/_csr.py +308 -0
- brainstate/event/_csr_test.py +118 -0
- brainstate/event/_fixed_probability.py +271 -0
- brainstate/event/_fixed_probability_test.py +128 -0
- brainstate/event/_linear.py +219 -0
- brainstate/event/_linear_test.py +112 -0
- brainstate/{nn/event → event}/_misc.py +7 -7
- brainstate/functional/_activations.py +521 -511
- brainstate/functional/_activations_test.py +300 -300
- brainstate/functional/_normalization.py +43 -43
- brainstate/functional/_others.py +15 -15
- brainstate/functional/_spikes.py +49 -49
- brainstate/graph/__init__.py +33 -0
- brainstate/graph/_graph_context.py +443 -0
- brainstate/graph/_graph_context_test.py +65 -0
- brainstate/graph/_graph_convert.py +246 -0
- brainstate/graph/_graph_node.py +300 -0
- brainstate/graph/_graph_node_test.py +75 -0
- brainstate/graph/_graph_operation.py +1746 -0
- brainstate/graph/_graph_operation_test.py +724 -0
- brainstate/init/_base.py +28 -10
- brainstate/init/_generic.py +175 -172
- brainstate/init/_random_inits.py +470 -415
- brainstate/init/_random_inits_test.py +150 -0
- brainstate/init/_regular_inits.py +66 -69
- brainstate/init/_regular_inits_test.py +51 -0
- brainstate/mixin.py +236 -244
- brainstate/mixin_test.py +44 -46
- brainstate/nn/__init__.py +26 -51
- brainstate/nn/_collective_ops.py +199 -0
- brainstate/nn/_dyn_impl/__init__.py +46 -0
- brainstate/nn/_dyn_impl/_dynamics_neuron.py +290 -0
- brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +162 -0
- brainstate/nn/_dyn_impl/_dynamics_synapse.py +320 -0
- brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +132 -0
- brainstate/nn/_dyn_impl/_inputs.py +154 -0
- brainstate/nn/{_projection/__init__.py → _dyn_impl/_projection_alignpost.py} +6 -13
- brainstate/nn/_dyn_impl/_rate_rnns.py +400 -0
- brainstate/nn/_dyn_impl/_rate_rnns_test.py +64 -0
- brainstate/nn/_dyn_impl/_readout.py +128 -0
- brainstate/nn/_dyn_impl/_readout_test.py +54 -0
- brainstate/nn/_dynamics/__init__.py +37 -0
- brainstate/nn/_dynamics/_dynamics_base.py +631 -0
- brainstate/nn/_dynamics/_dynamics_base_test.py +79 -0
- brainstate/nn/_dynamics/_projection_base.py +346 -0
- brainstate/nn/_dynamics/_state_delay.py +453 -0
- brainstate/nn/_dynamics/_synouts.py +161 -0
- brainstate/nn/_dynamics/_synouts_test.py +58 -0
- brainstate/nn/_elementwise/__init__.py +22 -0
- brainstate/nn/_elementwise/_dropout.py +418 -0
- brainstate/nn/_elementwise/_dropout_test.py +100 -0
- brainstate/nn/_elementwise/_elementwise.py +1122 -0
- brainstate/nn/_elementwise/_elementwise_test.py +171 -0
- brainstate/nn/_exp_euler.py +97 -0
- brainstate/nn/_exp_euler_test.py +36 -0
- brainstate/nn/_interaction/__init__.py +32 -0
- brainstate/nn/_interaction/_connections.py +726 -0
- brainstate/nn/_interaction/_connections_test.py +254 -0
- brainstate/nn/_interaction/_embedding.py +59 -0
- brainstate/nn/_interaction/_normalizations.py +388 -0
- brainstate/nn/_interaction/_normalizations_test.py +75 -0
- brainstate/nn/_interaction/_poolings.py +1179 -0
- brainstate/nn/_interaction/_poolings_test.py +219 -0
- brainstate/nn/_module.py +328 -0
- brainstate/nn/_module_test.py +211 -0
- brainstate/nn/metrics.py +309 -309
- brainstate/optim/__init__.py +14 -2
- brainstate/optim/_base.py +66 -0
- brainstate/optim/_lr_scheduler.py +363 -400
- brainstate/optim/_lr_scheduler_test.py +25 -24
- brainstate/optim/_optax_optimizer.py +103 -176
- brainstate/optim/_optax_optimizer_test.py +41 -1
- brainstate/optim/_sgd_optimizer.py +950 -1025
- brainstate/random/_rand_funs.py +3269 -3268
- brainstate/random/_rand_funs_test.py +568 -0
- brainstate/random/_rand_seed.py +149 -117
- brainstate/random/_rand_seed_test.py +50 -0
- brainstate/random/_rand_state.py +1360 -1318
- brainstate/random/_random_for_unit.py +13 -13
- brainstate/surrogate.py +1262 -1243
- brainstate/{nn/_projection/_utils.py → transform.py} +1 -2
- brainstate/typing.py +157 -130
- brainstate/util/__init__.py +52 -0
- brainstate/util/_caller.py +100 -0
- brainstate/util/_dict.py +734 -0
- brainstate/util/_dict_test.py +160 -0
- brainstate/util/_error.py +28 -0
- brainstate/util/_filter.py +178 -0
- brainstate/util/_others.py +497 -0
- brainstate/util/_pretty_repr.py +208 -0
- brainstate/util/_scaling.py +260 -0
- brainstate/util/_struct.py +524 -0
- brainstate/util/_tracers.py +75 -0
- brainstate/{_visualization.py → util/_visualization.py} +16 -16
- {brainstate-0.0.2.post20241009.dist-info → brainstate-0.1.0.dist-info}/METADATA +11 -11
- brainstate-0.1.0.dist-info/RECORD +135 -0
- brainstate/_module.py +0 -1637
- brainstate/_module_test.py +0 -207
- brainstate/nn/_base.py +0 -251
- brainstate/nn/_connections.py +0 -686
- brainstate/nn/_dynamics.py +0 -426
- brainstate/nn/_elementwise.py +0 -1438
- brainstate/nn/_embedding.py +0 -66
- brainstate/nn/_misc.py +0 -133
- brainstate/nn/_normalizations.py +0 -389
- brainstate/nn/_others.py +0 -101
- brainstate/nn/_poolings.py +0 -1229
- brainstate/nn/_poolings_test.py +0 -231
- brainstate/nn/_projection/_align_post.py +0 -546
- brainstate/nn/_projection/_align_pre.py +0 -599
- brainstate/nn/_projection/_delta.py +0 -241
- brainstate/nn/_projection/_vanilla.py +0 -101
- brainstate/nn/_rate_rnns.py +0 -410
- brainstate/nn/_readout.py +0 -136
- brainstate/nn/_synouts.py +0 -166
- brainstate/nn/event/csr.py +0 -312
- brainstate/nn/event/csr_test.py +0 -118
- brainstate/nn/event/fixed_probability.py +0 -276
- brainstate/nn/event/fixed_probability_test.py +0 -127
- brainstate/nn/event/linear.py +0 -220
- brainstate/nn/event/linear_test.py +0 -111
- brainstate/random/random_test.py +0 -593
- brainstate/transform/_autograd.py +0 -585
- brainstate/transform/_autograd_test.py +0 -1181
- brainstate/transform/_conditions.py +0 -334
- brainstate/transform/_conditions_test.py +0 -220
- brainstate/transform/_error_if.py +0 -94
- brainstate/transform/_error_if_test.py +0 -55
- brainstate/transform/_jit.py +0 -265
- brainstate/transform/_jit_test.py +0 -118
- brainstate/transform/_loop_collect_return.py +0 -502
- brainstate/transform/_loop_no_collection.py +0 -170
- brainstate/transform/_make_jaxpr.py +0 -739
- brainstate/transform/_make_jaxpr_test.py +0 -131
- brainstate/transform/_mapping.py +0 -109
- brainstate/transform/_progress_bar.py +0 -111
- brainstate/transform/_unvmap.py +0 -143
- brainstate/util.py +0 -746
- brainstate-0.0.2.post20241009.dist-info/RECORD +0 -87
- {brainstate-0.0.2.post20241009.dist-info → brainstate-0.1.0.dist-info}/LICENSE +0 -0
- {brainstate-0.0.2.post20241009.dist-info → brainstate-0.1.0.dist-info}/WHEEL +0 -0
- {brainstate-0.0.2.post20241009.dist-info → brainstate-0.1.0.dist-info}/top_level.txt +0 -0
brainstate/init/_random_inits.py
CHANGED
@@ -14,486 +14,541 @@
|
|
14
14
|
# ==============================================================================
|
15
15
|
|
16
16
|
# -*- coding: utf-8 -*-
|
17
|
+
from __future__ import annotations
|
17
18
|
|
18
19
|
import math
|
19
20
|
|
20
|
-
import brainunit as
|
21
|
+
import brainunit as u
|
21
22
|
import jax.numpy as jnp
|
22
23
|
import numpy as np
|
23
24
|
|
24
25
|
from brainstate import environ, random
|
25
|
-
from brainstate.typing import ArrayLike
|
26
|
+
from brainstate.typing import ArrayLike, SeedOrKey, DTypeLike
|
26
27
|
from ._base import Initializer, to_size
|
27
28
|
|
28
29
|
__all__ = [
|
29
|
-
|
30
|
-
|
31
|
-
|
32
|
-
|
33
|
-
|
34
|
-
|
35
|
-
|
36
|
-
|
37
|
-
|
38
|
-
|
39
|
-
|
40
|
-
|
30
|
+
'Normal',
|
31
|
+
'TruncatedNormal',
|
32
|
+
'Uniform',
|
33
|
+
'VarianceScaling',
|
34
|
+
'KaimingUniform',
|
35
|
+
'KaimingNormal',
|
36
|
+
'XavierUniform',
|
37
|
+
'XavierNormal',
|
38
|
+
'LecunUniform',
|
39
|
+
'LecunNormal',
|
40
|
+
'Orthogonal',
|
41
|
+
'DeltaOrthogonal',
|
41
42
|
]
|
42
43
|
|
43
44
|
|
44
45
|
def calculate_gain(nonlinearity, param=None):
|
45
|
-
|
46
|
-
|
47
|
-
|
48
|
-
|
49
|
-
|
50
|
-
|
51
|
-
|
52
|
-
|
53
|
-
|
54
|
-
|
55
|
-
|
56
|
-
|
57
|
-
|
58
|
-
|
59
|
-
|
60
|
-
|
61
|
-
|
62
|
-
|
63
|
-
|
64
|
-
|
65
|
-
|
66
|
-
|
67
|
-
|
68
|
-
|
69
|
-
|
70
|
-
|
71
|
-
|
72
|
-
|
73
|
-
|
74
|
-
|
75
|
-
|
76
|
-
|
77
|
-
|
78
|
-
|
79
|
-
|
80
|
-
|
81
|
-
|
82
|
-
|
83
|
-
|
84
|
-
|
85
|
-
|
86
|
-
|
46
|
+
r"""Return the recommended gain value for the given nonlinearity function.
|
47
|
+
The values are as follows:
|
48
|
+
|
49
|
+
================= ====================================================
|
50
|
+
nonlinearity gain
|
51
|
+
================= ====================================================
|
52
|
+
Linear / Identity :math:`1`
|
53
|
+
Conv{1,2,3}D :math:`1`
|
54
|
+
Sigmoid :math:`1`
|
55
|
+
Tanh :math:`\frac{5}{3}`
|
56
|
+
ReLU :math:`\sqrt{2}`
|
57
|
+
Leaky Relu :math:`\sqrt{\frac{2}{1 + \text{negative\_slope}^2}}`
|
58
|
+
SELU :math:`\frac{3}{4}`
|
59
|
+
================= ====================================================
|
60
|
+
|
61
|
+
.. warning::
|
62
|
+
In order to implement `Self-Normalizing Neural Networks`_ ,
|
63
|
+
you should use ``nonlinearity='linear'`` instead of ``nonlinearity='selu'``.
|
64
|
+
This gives the initial weights a variance of ``1 / N``,
|
65
|
+
which is necessary to induce a stable fixed point in the forward pass.
|
66
|
+
In contrast, the default gain for ``SELU`` sacrifices the normalisation
|
67
|
+
effect for more stable gradient flow in rectangular layers.
|
68
|
+
|
69
|
+
Args:
|
70
|
+
nonlinearity: the non-linear function (`nn.functional` name)
|
71
|
+
param: optional parameter for the non-linear function
|
72
|
+
|
73
|
+
.. _Self-Normalizing Neural Networks: https://papers.nips.cc/paper/2017/hash/5d44ee6f2c3f71b73125876103c8f6c4-Abstract.html
|
74
|
+
"""
|
75
|
+
linear_fns = ['linear', 'conv1d', 'conv2d', 'conv3d', 'conv_transpose1d', 'conv_transpose2d', 'conv_transpose3d']
|
76
|
+
if nonlinearity in linear_fns or nonlinearity == 'sigmoid':
|
77
|
+
return 1
|
78
|
+
elif nonlinearity == 'tanh':
|
79
|
+
return 5.0 / 3
|
80
|
+
elif nonlinearity == 'relu':
|
81
|
+
return math.sqrt(2.0)
|
82
|
+
elif nonlinearity == 'leaky_relu':
|
83
|
+
if param is None:
|
84
|
+
negative_slope = 0.01
|
85
|
+
elif not isinstance(param, bool) and isinstance(param, int) or isinstance(param, float):
|
86
|
+
# True/False are instances of int, hence check above
|
87
|
+
negative_slope = param
|
88
|
+
else:
|
89
|
+
raise ValueError("negative_slope {} not a valid number".format(param))
|
90
|
+
return math.sqrt(2.0 / (1 + negative_slope ** 2))
|
91
|
+
elif nonlinearity == 'selu':
|
92
|
+
return 3.0 / 4
|
87
93
|
else:
|
88
|
-
|
89
|
-
return math.sqrt(2.0 / (1 + negative_slope ** 2))
|
90
|
-
elif nonlinearity == 'selu':
|
91
|
-
return 3.0 / 4
|
92
|
-
else:
|
93
|
-
raise ValueError("Unsupported nonlinearity {}".format(nonlinearity))
|
94
|
+
raise ValueError("Unsupported nonlinearity {}".format(nonlinearity))
|
94
95
|
|
95
96
|
|
96
97
|
def _format_shape(shape):
|
97
|
-
|
98
|
-
|
99
|
-
|
100
|
-
|
101
|
-
|
102
|
-
|
103
|
-
|
98
|
+
if isinstance(shape, int):
|
99
|
+
return (shape,)
|
100
|
+
if len(shape) == 0:
|
101
|
+
raise ValueError('Please provide shape.')
|
102
|
+
if len(shape) == 1:
|
103
|
+
if isinstance(shape[0], (tuple, list)):
|
104
|
+
return shape[0]
|
105
|
+
else:
|
106
|
+
return shape
|
104
107
|
else:
|
105
|
-
|
106
|
-
else:
|
107
|
-
return shape
|
108
|
+
return shape
|
108
109
|
|
109
110
|
|
110
111
|
def _compute_fans(shape, in_axis=-2, out_axis=-1):
|
111
|
-
|
112
|
-
|
113
|
-
|
114
|
-
|
112
|
+
receptive_field_size = np.prod(shape) / shape[in_axis] / shape[out_axis]
|
113
|
+
fan_in = shape[in_axis] * receptive_field_size
|
114
|
+
fan_out = shape[out_axis] * receptive_field_size
|
115
|
+
return fan_in, fan_out
|
115
116
|
|
116
117
|
|
117
118
|
class Normal(Initializer):
|
118
|
-
|
119
|
-
|
120
|
-
|
121
|
-
|
122
|
-
|
123
|
-
|
124
|
-
|
125
|
-
|
126
|
-
|
127
|
-
|
128
|
-
|
129
|
-
|
130
|
-
|
131
|
-
|
132
|
-
|
133
|
-
|
134
|
-
|
135
|
-
|
136
|
-
|
137
|
-
|
138
|
-
|
139
|
-
|
119
|
+
"""Initialize weights with normal distribution.
|
120
|
+
|
121
|
+
Parameters
|
122
|
+
----------
|
123
|
+
scale : float
|
124
|
+
The gain of the derivation of the normal distribution.
|
125
|
+
|
126
|
+
"""
|
127
|
+
__module__ = 'brainstate.init'
|
128
|
+
|
129
|
+
def __init__(
|
130
|
+
self,
|
131
|
+
mean: ArrayLike = 0.,
|
132
|
+
scale: ArrayLike = 1.,
|
133
|
+
unit: u.Unit = u.UNITLESS,
|
134
|
+
seed: SeedOrKey = None
|
135
|
+
):
|
136
|
+
super().__init__()
|
137
|
+
self.scale = scale
|
138
|
+
self.mean = mean
|
139
|
+
self.rng = random.default_rng(seed)
|
140
|
+
self.unit = unit
|
141
|
+
|
142
|
+
def __call__(self, shape, dtype: DTypeLike = None):
|
143
|
+
shape = to_size(shape)
|
144
|
+
dtype = dtype or environ.dftype()
|
145
|
+
weights = self.rng.normal(size=shape, loc=self.mean, scale=self.scale, dtype=dtype)
|
146
|
+
return u.maybe_decimal(u.Quantity(weights, unit=self.unit))
|
140
147
|
|
141
148
|
|
142
149
|
class TruncatedNormal(Initializer):
|
143
|
-
|
144
|
-
|
145
|
-
|
146
|
-
|
147
|
-
|
148
|
-
|
149
|
-
|
150
|
-
|
151
|
-
|
152
|
-
|
153
|
-
|
154
|
-
|
155
|
-
|
156
|
-
|
157
|
-
|
158
|
-
|
159
|
-
|
160
|
-
|
161
|
-
|
162
|
-
|
163
|
-
|
164
|
-
|
165
|
-
|
166
|
-
|
167
|
-
|
168
|
-
|
169
|
-
|
170
|
-
|
171
|
-
|
172
|
-
|
173
|
-
|
174
|
-
|
175
|
-
|
176
|
-
|
177
|
-
|
178
|
-
|
179
|
-
|
180
|
-
|
181
|
-
|
182
|
-
|
183
|
-
|
184
|
-
|
150
|
+
"""Initialize weights with truncated normal distribution.
|
151
|
+
|
152
|
+
Parameters
|
153
|
+
----------
|
154
|
+
loc : float, ndarray
|
155
|
+
Mean ("centre") of the distribution before truncating. Note that
|
156
|
+
the mean of the truncated distribution will not be exactly equal
|
157
|
+
to ``loc``.
|
158
|
+
scale : float
|
159
|
+
The standard deviation of the normal distribution before truncating.
|
160
|
+
lower : float, ndarray
|
161
|
+
A float or array of floats representing the lower bound for
|
162
|
+
truncation. Must be broadcast-compatible with ``upper``.
|
163
|
+
upper : float, ndarray
|
164
|
+
A float or array of floats representing the upper bound for
|
165
|
+
truncation. Must be broadcast-compatible with ``lower``.
|
166
|
+
|
167
|
+
"""
|
168
|
+
__module__ = 'brainstate.init'
|
169
|
+
|
170
|
+
def __init__(
|
171
|
+
self,
|
172
|
+
loc: ArrayLike = 0.,
|
173
|
+
scale: ArrayLike = 1.,
|
174
|
+
unit: u.Unit = u.UNITLESS,
|
175
|
+
lower: ArrayLike = None,
|
176
|
+
upper: ArrayLike = None,
|
177
|
+
seed: SeedOrKey = None,
|
178
|
+
):
|
179
|
+
super().__init__()
|
180
|
+
assert scale > 0, '`scale` must be positive.'
|
181
|
+
self.scale = scale
|
182
|
+
self.loc = loc
|
183
|
+
self.lower = lower
|
184
|
+
self.upper = upper
|
185
|
+
self.rng = random.default_rng(seed)
|
186
|
+
self.unit = unit
|
187
|
+
|
188
|
+
def __call__(self, shape, dtype: DTypeLike = None, ):
|
189
|
+
dtype = dtype or environ.dftype()
|
190
|
+
weights = self.rng.truncated_normal(
|
191
|
+
size=shape,
|
192
|
+
scale=self.scale,
|
193
|
+
lower=self.lower,
|
194
|
+
upper=self.upper,
|
195
|
+
loc=self.loc,
|
196
|
+
dtype=dtype
|
197
|
+
)
|
198
|
+
return u.maybe_decimal(u.Quantity(weights, unit=self.unit))
|
185
199
|
|
186
200
|
|
187
201
|
class Gamma(Initializer):
|
188
|
-
|
189
|
-
|
190
|
-
|
191
|
-
|
192
|
-
|
193
|
-
|
194
|
-
|
195
|
-
|
196
|
-
|
197
|
-
|
198
|
-
|
199
|
-
|
200
|
-
|
201
|
-
|
202
|
-
|
203
|
-
|
204
|
-
|
205
|
-
|
206
|
-
|
207
|
-
|
208
|
-
|
209
|
-
|
210
|
-
|
202
|
+
"""Initialize weights with Gamma distribution.
|
203
|
+
|
204
|
+
Parameters
|
205
|
+
----------
|
206
|
+
shape: float, Array
|
207
|
+
Shape parameter.
|
208
|
+
scale: float, Array
|
209
|
+
The gain of the derivation of the Gamma distribution.
|
210
|
+
|
211
|
+
"""
|
212
|
+
__module__ = 'brainstate.init'
|
213
|
+
|
214
|
+
def __init__(
|
215
|
+
self,
|
216
|
+
shape: ArrayLike,
|
217
|
+
unit: u.Unit = u.UNITLESS,
|
218
|
+
scale: ArrayLike = None,
|
219
|
+
seed: SeedOrKey = None
|
220
|
+
):
|
221
|
+
self.shape = shape
|
222
|
+
self.scale = scale
|
223
|
+
self.rng = random.default_rng(seed)
|
224
|
+
self.unit = unit
|
225
|
+
|
226
|
+
def __call__(self, shape, dtype: DTypeLike = None, ):
|
227
|
+
shape = to_size(shape)
|
228
|
+
dtype = dtype or environ.dftype()
|
229
|
+
weights = self.rng.gamma(self.shape, scale=self.scale, size=shape, dtype=dtype)
|
230
|
+
return u.maybe_decimal(u.Quantity(weights, unit=self.unit))
|
211
231
|
|
212
232
|
|
213
233
|
class Exponential(Initializer):
|
214
|
-
|
234
|
+
"""Initialize weights with Gamma distribution.
|
215
235
|
|
216
|
-
|
217
|
-
|
218
|
-
|
219
|
-
|
236
|
+
Parameters
|
237
|
+
----------
|
238
|
+
scale: float, Array
|
239
|
+
The gain of the derivation of the Exponential distribution.
|
220
240
|
|
221
|
-
|
241
|
+
"""
|
242
|
+
__module__ = 'brainstate.init'
|
222
243
|
|
223
|
-
|
224
|
-
|
225
|
-
|
244
|
+
def __init__(
|
245
|
+
self,
|
246
|
+
scale: ArrayLike = None,
|
247
|
+
seed: SeedOrKey = None,
|
248
|
+
unit: u.Unit = u.UNITLESS,
|
249
|
+
):
|
250
|
+
self.scale = scale
|
251
|
+
self.rng = random.default_rng(seed)
|
252
|
+
self.unit = unit
|
226
253
|
|
227
|
-
|
228
|
-
|
229
|
-
|
230
|
-
|
231
|
-
|
232
|
-
def __repr__(self):
|
233
|
-
return f'{self.__class__.__name__}(scale={self.scale}, dtype={self.dtype})'
|
254
|
+
def __call__(self, shape, dtype: DTypeLike = None, ):
|
255
|
+
shape = to_size(shape)
|
256
|
+
dtype = dtype or environ.dftype()
|
257
|
+
weights = self.rng.exponential(scale=self.scale, size=shape, dtype=dtype)
|
258
|
+
return u.maybe_decimal(u.Quantity(weights, unit=self.unit))
|
234
259
|
|
235
260
|
|
236
261
|
class Uniform(Initializer):
|
237
|
-
|
238
|
-
|
239
|
-
|
240
|
-
|
241
|
-
|
242
|
-
|
243
|
-
|
244
|
-
|
245
|
-
|
246
|
-
|
247
|
-
|
248
|
-
|
249
|
-
|
250
|
-
|
251
|
-
|
252
|
-
|
253
|
-
|
254
|
-
|
255
|
-
|
256
|
-
|
257
|
-
|
258
|
-
|
259
|
-
|
262
|
+
"""Initialize weights with uniform distribution.
|
263
|
+
|
264
|
+
Parameters
|
265
|
+
----------
|
266
|
+
min_val : float
|
267
|
+
The lower limit of the uniform distribution.
|
268
|
+
max_val : float
|
269
|
+
The upper limit of the uniform distribution.
|
270
|
+
"""
|
271
|
+
__module__ = 'brainstate.init'
|
272
|
+
|
273
|
+
def __init__(
|
274
|
+
self,
|
275
|
+
min_val: ArrayLike = 0.,
|
276
|
+
max_val: ArrayLike = 1.,
|
277
|
+
seed: SeedOrKey = None,
|
278
|
+
unit: u.Unit = u.UNITLESS,
|
279
|
+
):
|
280
|
+
super(Uniform, self).__init__()
|
281
|
+
self.min_val = min_val
|
282
|
+
self.max_val = max_val
|
283
|
+
self.rng = random.default_rng(seed)
|
284
|
+
self.unit = unit
|
285
|
+
|
286
|
+
def __call__(self, shape, dtype: DTypeLike = None, ):
|
287
|
+
shape = to_size(shape)
|
288
|
+
dtype = dtype or environ.dftype()
|
289
|
+
weights = self.rng.uniform(low=self.min_val, high=self.max_val, size=shape, dtype=dtype)
|
290
|
+
return u.maybe_decimal(u.Quantity(weights, unit=self.unit))
|
260
291
|
|
261
292
|
|
262
293
|
class VarianceScaling(Initializer):
|
263
|
-
|
264
|
-
|
265
|
-
|
266
|
-
|
267
|
-
|
268
|
-
|
269
|
-
|
270
|
-
|
271
|
-
|
272
|
-
|
273
|
-
|
274
|
-
|
275
|
-
|
276
|
-
|
277
|
-
|
278
|
-
|
279
|
-
|
280
|
-
|
281
|
-
|
282
|
-
|
283
|
-
|
284
|
-
|
285
|
-
|
286
|
-
|
287
|
-
|
288
|
-
|
289
|
-
|
290
|
-
|
291
|
-
|
292
|
-
|
293
|
-
|
294
|
-
|
295
|
-
|
296
|
-
|
297
|
-
|
298
|
-
|
299
|
-
|
300
|
-
|
301
|
-
|
302
|
-
|
303
|
-
|
304
|
-
|
305
|
-
|
306
|
-
|
307
|
-
|
308
|
-
|
309
|
-
blank = ' ' * len(name)
|
310
|
-
return (f'{name}(scale={self.scale}, mode={self.mode}, in_axis={self.in_axis}, \n'
|
311
|
-
f'{blank}out_axis={self.out_axis}, distribution={self.distribution}, dtype={self.dtype})')
|
294
|
+
__module__ = 'brainstate.init'
|
295
|
+
|
296
|
+
def __init__(
|
297
|
+
self,
|
298
|
+
scale: ArrayLike,
|
299
|
+
mode: str,
|
300
|
+
distribution: str,
|
301
|
+
in_axis: int = -2,
|
302
|
+
out_axis: int = -1,
|
303
|
+
seed: SeedOrKey = None,
|
304
|
+
unit: u.Unit = u.UNITLESS,
|
305
|
+
):
|
306
|
+
assert mode in ['fan_in', 'fan_out', 'fan_avg']
|
307
|
+
assert distribution in ['truncated_normal', 'normal', 'uniform']
|
308
|
+
self.scale = scale
|
309
|
+
self.mode = mode
|
310
|
+
self.in_axis = in_axis
|
311
|
+
self.out_axis = out_axis
|
312
|
+
self.distribution = distribution
|
313
|
+
self.rng = random.default_rng(seed)
|
314
|
+
self.unit = unit
|
315
|
+
|
316
|
+
def __call__(self, shape, dtype: DTypeLike = None, ):
|
317
|
+
shape = to_size(shape)
|
318
|
+
dtype = dtype or environ.dftype()
|
319
|
+
fan_in, fan_out = _compute_fans(shape, in_axis=self.in_axis, out_axis=self.out_axis)
|
320
|
+
if self.mode == "fan_in":
|
321
|
+
denominator = fan_in
|
322
|
+
elif self.mode == "fan_out":
|
323
|
+
denominator = fan_out
|
324
|
+
elif self.mode == "fan_avg":
|
325
|
+
denominator = (fan_in + fan_out) / 2
|
326
|
+
else:
|
327
|
+
raise ValueError("invalid mode for variance scaling initializer: {}".format(self.mode))
|
328
|
+
variance = (self.scale / denominator).astype(dtype)
|
329
|
+
if self.distribution == "truncated_normal":
|
330
|
+
stddev = (jnp.sqrt(variance) / .87962566103423978).astype(dtype)
|
331
|
+
res = self.rng.truncated_normal(-2, 2, shape, dtype=dtype) * stddev
|
332
|
+
elif self.distribution == "normal":
|
333
|
+
res = self.rng.randn(*shape, dtype=dtype) * jnp.sqrt(variance).astype(dtype)
|
334
|
+
elif self.distribution == "uniform":
|
335
|
+
res = (self.rng.uniform(low=-1, high=1, size=shape, dtype=dtype) *
|
336
|
+
jnp.sqrt(3 * variance).astype(dtype))
|
337
|
+
else:
|
338
|
+
raise ValueError("invalid distribution for variance scaling initializer")
|
339
|
+
return u.maybe_decimal(u.Quantity(res, unit=self.unit))
|
312
340
|
|
313
341
|
|
314
342
|
class KaimingUniform(VarianceScaling):
|
315
|
-
|
316
|
-
|
317
|
-
|
318
|
-
|
319
|
-
|
320
|
-
|
321
|
-
|
322
|
-
|
323
|
-
|
324
|
-
|
325
|
-
|
326
|
-
|
327
|
-
|
328
|
-
|
329
|
-
|
343
|
+
__module__ = 'brainstate.init'
|
344
|
+
|
345
|
+
def __init__(
|
346
|
+
self,
|
347
|
+
scale: float = 2.0,
|
348
|
+
mode: str = "fan_in",
|
349
|
+
distribution: str = "uniform",
|
350
|
+
in_axis: int = -2,
|
351
|
+
out_axis: int = -1,
|
352
|
+
seed: SeedOrKey = None,
|
353
|
+
unit: u.Unit = u.UNITLESS,
|
354
|
+
):
|
355
|
+
super().__init__(scale,
|
356
|
+
mode,
|
357
|
+
distribution,
|
358
|
+
in_axis=in_axis,
|
359
|
+
out_axis=out_axis,
|
360
|
+
seed=seed,
|
361
|
+
unit=unit)
|
330
362
|
|
331
363
|
|
332
364
|
class KaimingNormal(VarianceScaling):
|
333
|
-
|
334
|
-
|
335
|
-
|
336
|
-
|
337
|
-
|
338
|
-
|
339
|
-
|
340
|
-
|
341
|
-
|
342
|
-
|
343
|
-
|
344
|
-
|
345
|
-
|
346
|
-
|
347
|
-
|
365
|
+
__module__ = 'brainstate.init'
|
366
|
+
|
367
|
+
def __init__(
|
368
|
+
self,
|
369
|
+
scale: float = 2.0,
|
370
|
+
mode: str = "fan_in",
|
371
|
+
distribution: str = "truncated_normal",
|
372
|
+
in_axis: int = -2,
|
373
|
+
out_axis: int = -1,
|
374
|
+
seed: SeedOrKey = None,
|
375
|
+
unit: u.Unit = u.UNITLESS,
|
376
|
+
):
|
377
|
+
super().__init__(scale,
|
378
|
+
mode,
|
379
|
+
distribution,
|
380
|
+
in_axis=in_axis,
|
381
|
+
out_axis=out_axis,
|
382
|
+
seed=seed,
|
383
|
+
unit=unit)
|
348
384
|
|
349
385
|
|
350
386
|
class XavierUniform(VarianceScaling):
|
351
|
-
|
352
|
-
|
353
|
-
|
354
|
-
|
355
|
-
|
356
|
-
|
357
|
-
|
358
|
-
|
359
|
-
|
360
|
-
|
361
|
-
|
362
|
-
|
363
|
-
|
364
|
-
|
365
|
-
|
387
|
+
__module__ = 'brainstate.init'
|
388
|
+
|
389
|
+
def __init__(
|
390
|
+
self,
|
391
|
+
scale: float = 1.0,
|
392
|
+
mode: str = "fan_avg",
|
393
|
+
distribution: str = "uniform",
|
394
|
+
in_axis: int = -2,
|
395
|
+
out_axis: int = -1,
|
396
|
+
seed: SeedOrKey = None,
|
397
|
+
unit: u.Unit = u.UNITLESS,
|
398
|
+
):
|
399
|
+
super().__init__(scale,
|
400
|
+
mode,
|
401
|
+
distribution,
|
402
|
+
in_axis=in_axis,
|
403
|
+
out_axis=out_axis,
|
404
|
+
seed=seed,
|
405
|
+
unit=unit)
|
366
406
|
|
367
407
|
|
368
408
|
class XavierNormal(VarianceScaling):
|
369
|
-
|
370
|
-
|
371
|
-
|
372
|
-
|
373
|
-
|
374
|
-
|
375
|
-
|
376
|
-
|
377
|
-
|
378
|
-
|
379
|
-
|
380
|
-
|
381
|
-
|
382
|
-
|
383
|
-
|
409
|
+
__module__ = 'brainstate.init'
|
410
|
+
|
411
|
+
def __init__(
|
412
|
+
self,
|
413
|
+
scale: float = 1.0,
|
414
|
+
mode: str = "fan_avg",
|
415
|
+
distribution: str = "truncated_normal",
|
416
|
+
in_axis: int = -2,
|
417
|
+
out_axis: int = -1,
|
418
|
+
seed: SeedOrKey = None,
|
419
|
+
unit: u.Unit = u.UNITLESS,
|
420
|
+
):
|
421
|
+
super().__init__(scale,
|
422
|
+
mode,
|
423
|
+
distribution,
|
424
|
+
in_axis=in_axis,
|
425
|
+
out_axis=out_axis,
|
426
|
+
seed=seed,
|
427
|
+
unit=unit)
|
384
428
|
|
385
429
|
|
386
430
|
class LecunUniform(VarianceScaling):
|
387
|
-
|
388
|
-
|
389
|
-
|
390
|
-
|
391
|
-
|
392
|
-
|
393
|
-
|
394
|
-
|
395
|
-
|
396
|
-
|
397
|
-
|
398
|
-
|
399
|
-
|
400
|
-
|
401
|
-
|
431
|
+
__module__ = 'brainstate.init'
|
432
|
+
|
433
|
+
def __init__(
|
434
|
+
self,
|
435
|
+
scale: float = 1.0,
|
436
|
+
mode: str = "fan_in",
|
437
|
+
distribution: str = "uniform",
|
438
|
+
in_axis: int = -2,
|
439
|
+
out_axis: int = -1,
|
440
|
+
seed: SeedOrKey = None,
|
441
|
+
unit: u.Unit = u.UNITLESS,
|
442
|
+
):
|
443
|
+
super().__init__(scale,
|
444
|
+
mode,
|
445
|
+
distribution,
|
446
|
+
in_axis=in_axis,
|
447
|
+
out_axis=out_axis,
|
448
|
+
seed=seed,
|
449
|
+
unit=unit)
|
402
450
|
|
403
451
|
|
404
452
|
class LecunNormal(VarianceScaling):
|
405
|
-
|
406
|
-
|
407
|
-
|
408
|
-
|
409
|
-
|
410
|
-
|
411
|
-
|
412
|
-
|
413
|
-
|
414
|
-
|
415
|
-
|
416
|
-
|
417
|
-
|
418
|
-
|
419
|
-
|
453
|
+
__module__ = 'brainstate.init'
|
454
|
+
|
455
|
+
def __init__(
|
456
|
+
self,
|
457
|
+
scale: float = 1.0,
|
458
|
+
mode: str = "fan_in",
|
459
|
+
distribution: str = "truncated_normal",
|
460
|
+
in_axis: int = -2,
|
461
|
+
out_axis: int = -1,
|
462
|
+
seed: SeedOrKey = None,
|
463
|
+
unit: u.Unit = u.UNITLESS,
|
464
|
+
):
|
465
|
+
super().__init__(scale,
|
466
|
+
mode,
|
467
|
+
distribution,
|
468
|
+
in_axis=in_axis,
|
469
|
+
out_axis=out_axis,
|
470
|
+
seed=seed,
|
471
|
+
unit=unit)
|
420
472
|
|
421
473
|
|
422
474
|
class Orthogonal(Initializer):
|
423
|
-
|
424
|
-
|
425
|
-
|
426
|
-
|
427
|
-
|
428
|
-
|
429
|
-
|
430
|
-
|
431
|
-
|
432
|
-
|
433
|
-
|
434
|
-
|
435
|
-
|
436
|
-
|
437
|
-
|
438
|
-
|
439
|
-
|
440
|
-
|
441
|
-
|
442
|
-
|
443
|
-
|
444
|
-
|
445
|
-
|
446
|
-
|
447
|
-
|
448
|
-
|
449
|
-
|
450
|
-
|
451
|
-
|
452
|
-
|
453
|
-
|
454
|
-
|
455
|
-
|
456
|
-
|
457
|
-
|
458
|
-
|
459
|
-
|
460
|
-
|
461
|
-
return f'{self.__class__.__name__}(scale={self.scale}, axis={self.axis}, dtype={self.dtype})'
|
475
|
+
"""
|
476
|
+
Construct an initializer for uniformly distributed orthogonal matrices.
|
477
|
+
|
478
|
+
If the shape is not square, the matrix will have orthonormal rows or columns
|
479
|
+
depending on which side is smaller.
|
480
|
+
"""
|
481
|
+
__module__ = 'brainstate.init'
|
482
|
+
|
483
|
+
def __init__(
|
484
|
+
self,
|
485
|
+
scale: ArrayLike = 1.,
|
486
|
+
axis: int = -1,
|
487
|
+
seed: SeedOrKey = None,
|
488
|
+
unit: u.Unit = u.UNITLESS,
|
489
|
+
):
|
490
|
+
super().__init__()
|
491
|
+
self.scale = scale
|
492
|
+
self.axis = axis
|
493
|
+
self.rng = random.default_rng(seed)
|
494
|
+
self.unit = unit
|
495
|
+
|
496
|
+
def __call__(self, shape, dtype: DTypeLike = None, ):
|
497
|
+
dtype = dtype or environ.dftype()
|
498
|
+
shape = to_size(shape)
|
499
|
+
n_rows = shape[self.axis]
|
500
|
+
n_cols = np.prod(shape) // n_rows
|
501
|
+
matrix_shape = (n_rows, n_cols) if n_rows > n_cols else (n_cols, n_rows)
|
502
|
+
norm_dst = self.rng.normal(size=matrix_shape, dtype=dtype)
|
503
|
+
|
504
|
+
q_mat, r_mat = jnp.linalg.qr(norm_dst)
|
505
|
+
# Enforce Q is uniformly distributed
|
506
|
+
q_mat *= jnp.sign(jnp.diag(r_mat))
|
507
|
+
if n_rows < n_cols:
|
508
|
+
q_mat = q_mat.T
|
509
|
+
q_mat = jnp.reshape(q_mat, (n_rows,) + tuple(np.delete(shape, self.axis)))
|
510
|
+
q_mat = jnp.moveaxis(q_mat, 0, self.axis)
|
511
|
+
r = jnp.asarray(self.scale, dtype=dtype) * q_mat
|
512
|
+
return u.maybe_decimal(u.Quantity(r, unit=self.unit))
|
462
513
|
|
463
514
|
|
464
515
|
class DeltaOrthogonal(Initializer):
|
465
|
-
|
466
|
-
|
467
|
-
|
468
|
-
|
469
|
-
|
470
|
-
|
471
|
-
|
472
|
-
|
473
|
-
|
474
|
-
|
475
|
-
|
476
|
-
|
477
|
-
|
478
|
-
|
479
|
-
|
480
|
-
|
481
|
-
|
482
|
-
|
483
|
-
|
484
|
-
|
485
|
-
|
486
|
-
|
487
|
-
|
488
|
-
|
489
|
-
|
490
|
-
|
491
|
-
|
492
|
-
|
493
|
-
|
494
|
-
|
495
|
-
|
496
|
-
|
497
|
-
|
498
|
-
|
499
|
-
|
516
|
+
"""
|
517
|
+
Construct an initializer for delta orthogonal kernels; see arXiv:1806.05393.
|
518
|
+
|
519
|
+
The shape must be 3D, 4D or 5D.
|
520
|
+
"""
|
521
|
+
__module__ = 'brainstate.init'
|
522
|
+
|
523
|
+
def __init__(
|
524
|
+
self,
|
525
|
+
scale: ArrayLike = 1.0,
|
526
|
+
axis: int = -1,
|
527
|
+
seed: SeedOrKey = None,
|
528
|
+
unit: u.Unit = u.UNITLESS,
|
529
|
+
):
|
530
|
+
super().__init__()
|
531
|
+
self.scale = scale
|
532
|
+
self.axis = axis
|
533
|
+
self.orghogonal = Orthogonal(scale=scale, axis=axis, seed=seed)
|
534
|
+
self.unit = unit
|
535
|
+
|
536
|
+
def __call__(self, shape, dtype: DTypeLike = None, ):
|
537
|
+
shape = to_size(shape)
|
538
|
+
dtype = dtype or environ.dftype()
|
539
|
+
if len(shape) not in [3, 4, 5]:
|
540
|
+
raise ValueError("Delta orthogonal initializer requires a 3D, 4D or 5D shape.")
|
541
|
+
if shape[-1] < shape[-2]:
|
542
|
+
raise ValueError("`fan_in` must be less or equal than `fan_out`. ")
|
543
|
+
ortho_matrix = u.Quantity(self.orghogonal(shape[-2:]))
|
544
|
+
W = u.Quantity(u.math.zeros(shape, dtype=dtype), unit=u.get_unit(ortho_matrix))
|
545
|
+
if len(shape) == 3:
|
546
|
+
k = shape[0]
|
547
|
+
W = W.at[(k - 1) // 2].set(ortho_matrix)
|
548
|
+
elif len(shape) == 4:
|
549
|
+
k1, k2 = shape[:2]
|
550
|
+
W = W.at[(k1 - 1) // 2, (k2 - 1) // 2].set(ortho_matrix)
|
551
|
+
else:
|
552
|
+
k1, k2, k3 = shape[:3]
|
553
|
+
W = W.at[(k1 - 1) // 2, (k2 - 1) // 2, (k3 - 1) // 2].set(ortho_matrix)
|
554
|
+
return u.maybe_decimal(u.Quantity(W.mantissa, unit=self.unit))
|