brainstate 0.1.10__py2.py3-none-any.whl → 0.2.0__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +130 -19
- brainstate/_compatible_import.py +201 -9
- brainstate/_compatible_import_test.py +681 -0
- brainstate/_deprecation.py +210 -0
- brainstate/_deprecation_test.py +2319 -0
- brainstate/{util/error.py → _error.py} +10 -20
- brainstate/_state.py +94 -47
- brainstate/_state_test.py +1 -1
- brainstate/_utils.py +1 -1
- brainstate/environ.py +1279 -347
- brainstate/environ_test.py +1187 -26
- brainstate/graph/__init__.py +6 -13
- brainstate/graph/_node.py +240 -0
- brainstate/graph/_node_test.py +589 -0
- brainstate/graph/{_graph_operation.py → _operation.py} +632 -746
- brainstate/graph/_operation_test.py +1147 -0
- brainstate/mixin.py +1209 -141
- brainstate/mixin_test.py +991 -51
- brainstate/nn/__init__.py +74 -72
- brainstate/nn/_activations.py +587 -295
- brainstate/nn/_activations_test.py +109 -86
- brainstate/nn/_collective_ops.py +393 -274
- brainstate/nn/_collective_ops_test.py +746 -15
- brainstate/nn/_common.py +114 -66
- brainstate/nn/_common_test.py +154 -0
- brainstate/nn/_conv.py +1652 -143
- brainstate/nn/_conv_test.py +838 -227
- brainstate/nn/_delay.py +15 -28
- brainstate/nn/_delay_test.py +25 -20
- brainstate/nn/_dropout.py +359 -167
- brainstate/nn/_dropout_test.py +429 -52
- brainstate/nn/_dynamics.py +14 -90
- brainstate/nn/_dynamics_test.py +1 -12
- brainstate/nn/_elementwise.py +492 -313
- brainstate/nn/_elementwise_test.py +806 -145
- brainstate/nn/_embedding.py +369 -19
- brainstate/nn/_embedding_test.py +156 -0
- brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +10 -16
- brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +6 -5
- brainstate/nn/{_linear_mv.py → _event_linear.py} +2 -2
- brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +6 -5
- brainstate/nn/_exp_euler.py +200 -38
- brainstate/nn/_exp_euler_test.py +350 -8
- brainstate/nn/_linear.py +391 -71
- brainstate/nn/_linear_test.py +427 -59
- brainstate/nn/_metrics.py +1070 -0
- brainstate/nn/_metrics_test.py +611 -0
- brainstate/nn/_module.py +10 -3
- brainstate/nn/_module_test.py +1 -1
- brainstate/nn/_normalizations.py +688 -329
- brainstate/nn/_normalizations_test.py +663 -37
- brainstate/nn/_paddings.py +1020 -0
- brainstate/nn/_paddings_test.py +723 -0
- brainstate/nn/_poolings.py +1404 -342
- brainstate/nn/_poolings_test.py +828 -92
- brainstate/nn/{_rate_rnns.py → _rnns.py} +446 -54
- brainstate/nn/_rnns_test.py +593 -0
- brainstate/nn/_utils.py +132 -5
- brainstate/nn/_utils_test.py +402 -0
- brainstate/{init/_random_inits.py → nn/init.py} +301 -45
- brainstate/{init/_random_inits_test.py → nn/init_test.py} +51 -20
- brainstate/random/__init__.py +247 -1
- brainstate/random/_rand_funs.py +668 -346
- brainstate/random/_rand_funs_test.py +74 -1
- brainstate/random/_rand_seed.py +541 -76
- brainstate/random/_rand_seed_test.py +1 -1
- brainstate/random/_rand_state.py +601 -393
- brainstate/random/_rand_state_test.py +551 -0
- brainstate/transform/__init__.py +59 -0
- brainstate/transform/_ad_checkpoint.py +176 -0
- brainstate/{compile → transform}/_ad_checkpoint_test.py +1 -1
- brainstate/{augment → transform}/_autograd.py +360 -113
- brainstate/{augment → transform}/_autograd_test.py +2 -2
- brainstate/transform/_conditions.py +316 -0
- brainstate/{compile → transform}/_conditions_test.py +11 -11
- brainstate/{compile → transform}/_error_if.py +22 -20
- brainstate/{compile → transform}/_error_if_test.py +1 -1
- brainstate/transform/_eval_shape.py +145 -0
- brainstate/{augment → transform}/_eval_shape_test.py +1 -1
- brainstate/{compile → transform}/_jit.py +99 -46
- brainstate/{compile → transform}/_jit_test.py +3 -3
- brainstate/{compile → transform}/_loop_collect_return.py +219 -80
- brainstate/{compile → transform}/_loop_collect_return_test.py +1 -1
- brainstate/{compile → transform}/_loop_no_collection.py +133 -34
- brainstate/{compile → transform}/_loop_no_collection_test.py +2 -2
- brainstate/transform/_make_jaxpr.py +2016 -0
- brainstate/transform/_make_jaxpr_test.py +1510 -0
- brainstate/transform/_mapping.py +529 -0
- brainstate/transform/_mapping_test.py +194 -0
- brainstate/{compile → transform}/_progress_bar.py +78 -25
- brainstate/{augment → transform}/_random.py +65 -45
- brainstate/{compile → transform}/_unvmap.py +102 -5
- brainstate/transform/_util.py +286 -0
- brainstate/typing.py +594 -61
- brainstate/typing_test.py +780 -0
- brainstate/util/__init__.py +9 -32
- brainstate/util/_others.py +1025 -0
- brainstate/util/_others_test.py +962 -0
- brainstate/util/_pretty_pytree.py +1301 -0
- brainstate/util/_pretty_pytree_test.py +675 -0
- brainstate/util/{pretty_repr.py → _pretty_repr.py} +161 -27
- brainstate/util/_pretty_repr_test.py +696 -0
- brainstate/util/filter.py +557 -81
- brainstate/util/filter_test.py +912 -0
- brainstate/util/struct.py +769 -382
- brainstate/util/struct_test.py +602 -0
- {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/METADATA +34 -17
- brainstate-0.2.0.dist-info/RECORD +111 -0
- brainstate/augment/__init__.py +0 -30
- brainstate/augment/_eval_shape.py +0 -99
- brainstate/augment/_mapping.py +0 -1060
- brainstate/augment/_mapping_test.py +0 -597
- brainstate/compile/__init__.py +0 -38
- brainstate/compile/_ad_checkpoint.py +0 -204
- brainstate/compile/_conditions.py +0 -256
- brainstate/compile/_make_jaxpr.py +0 -888
- brainstate/compile/_make_jaxpr_test.py +0 -156
- brainstate/compile/_util.py +0 -147
- brainstate/functional/__init__.py +0 -27
- brainstate/graph/_graph_node.py +0 -244
- brainstate/graph/_graph_node_test.py +0 -73
- brainstate/graph/_graph_operation_test.py +0 -563
- brainstate/init/__init__.py +0 -26
- brainstate/init/_base.py +0 -52
- brainstate/init/_generic.py +0 -244
- brainstate/init/_regular_inits.py +0 -105
- brainstate/init/_regular_inits_test.py +0 -50
- brainstate/nn/_inputs.py +0 -608
- brainstate/nn/_ltp.py +0 -28
- brainstate/nn/_neuron.py +0 -705
- brainstate/nn/_neuron_test.py +0 -161
- brainstate/nn/_others.py +0 -46
- brainstate/nn/_projection.py +0 -486
- brainstate/nn/_rate_rnns_test.py +0 -63
- brainstate/nn/_readout.py +0 -209
- brainstate/nn/_readout_test.py +0 -53
- brainstate/nn/_stp.py +0 -236
- brainstate/nn/_synapse.py +0 -505
- brainstate/nn/_synapse_test.py +0 -131
- brainstate/nn/_synaptic_projection.py +0 -423
- brainstate/nn/_synouts.py +0 -162
- brainstate/nn/_synouts_test.py +0 -57
- brainstate/nn/metrics.py +0 -388
- brainstate/optim/__init__.py +0 -38
- brainstate/optim/_base.py +0 -64
- brainstate/optim/_lr_scheduler.py +0 -448
- brainstate/optim/_lr_scheduler_test.py +0 -50
- brainstate/optim/_optax_optimizer.py +0 -152
- brainstate/optim/_optax_optimizer_test.py +0 -53
- brainstate/optim/_sgd_optimizer.py +0 -1104
- brainstate/random/_random_for_unit.py +0 -52
- brainstate/surrogate.py +0 -1957
- brainstate/transform.py +0 -23
- brainstate/util/caller.py +0 -98
- brainstate/util/others.py +0 -540
- brainstate/util/pretty_pytree.py +0 -945
- brainstate/util/pretty_pytree_test.py +0 -159
- brainstate/util/pretty_table.py +0 -2954
- brainstate/util/scaling.py +0 -258
- brainstate-0.1.10.dist-info/RECORD +0 -130
- {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/WHEEL +0 -0
- {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/licenses/LICENSE +0 -0
- {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/top_level.txt +0 -0
@@ -1,4 +1,4 @@
|
|
1
|
-
# Copyright
|
1
|
+
# Copyright 2025 BrainX Ecosystem Limited. All Rights Reserved.
|
2
2
|
#
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
4
|
# you may not use this file except in compliance with the License.
|
@@ -13,19 +13,27 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
15
|
|
16
|
-
# -*- coding: utf-8 -*-
|
17
|
-
|
18
16
|
import math
|
17
|
+
from typing import Optional, Tuple
|
18
|
+
from typing import Union, Callable, Sequence
|
19
19
|
|
20
20
|
import brainunit as u
|
21
|
+
import jax
|
21
22
|
import jax.numpy as jnp
|
22
23
|
import numpy as np
|
23
24
|
|
24
25
|
from brainstate import environ, random
|
25
|
-
from brainstate.
|
26
|
-
from .
|
26
|
+
from brainstate._state import State
|
27
|
+
from brainstate._utils import set_module_as
|
28
|
+
from brainstate.typing import ArrayLike, SeedOrKey
|
29
|
+
from brainstate.util import PrettyRepr, PrettyType, PrettyAttr
|
27
30
|
|
28
31
|
__all__ = [
|
32
|
+
'param',
|
33
|
+
'calculate_init_gain',
|
34
|
+
'ZeroInit',
|
35
|
+
'Constant',
|
36
|
+
'Identity',
|
29
37
|
'Normal',
|
30
38
|
'TruncatedNormal',
|
31
39
|
'Uniform',
|
@@ -41,7 +49,170 @@ __all__ = [
|
|
41
49
|
]
|
42
50
|
|
43
51
|
|
44
|
-
|
52
|
+
class Initializer(PrettyRepr):
    """
    Abstract base for every parameter initializer.

    Subclasses override ``__call__`` to produce the initialized value. The
    pretty representation lists each instance attribute whose name does not
    start with an underscore.
    """
    __module__ = 'brainstate.nn'

    def __call__(self, *args, **kwargs):
        # Concrete initializers are expected to override this method.
        raise NotImplementedError

    def __pretty_repr__(self):
        """
        Yield the type header followed by one entry per public attribute.
        """
        yield PrettyType(type=type(self))
        public_items = (
            (attr_name, attr_value)
            for attr_name, attr_value in vars(self).items()
            if not attr_name.startswith('_')
        )
        for attr_name, attr_value in public_items:
            yield PrettyAttr(attr_name, repr(attr_value))
|
70
|
+
|
71
|
+
|
72
|
+
def to_size(x) -> Optional[Tuple[int, ...]]:
    """
    Normalize a size specification into a tuple of dimensions.

    Parameters
    ----------
    x : tuple, list, int, numpy integer or None
        The size specification.

    Returns
    -------
    tuple or None
        ``tuple(x)`` for a tuple/list, ``(x,)`` for a single integer, and
        ``None`` when ``x`` is ``None``.

    Raises
    ------
    ValueError
        If ``x`` is none of the supported types.
    """
    if x is None:
        return None
    if isinstance(x, (tuple, list)):
        return tuple(x)
    if isinstance(x, (int, np.integer)):
        return (x,)
    raise ValueError(f'Cannot make a size for {x}')
|
80
|
+
|
81
|
+
|
82
|
+
def _is_scalar(x):
    """Report whether ``x`` is a scalar, as decided by ``brainunit.math.isscalar``."""
    scalar_flag = u.math.isscalar(x)
    return scalar_flag
|
84
|
+
|
85
|
+
|
86
|
+
def are_broadcastable_shapes(shape1, shape2):
    """
    Tell whether two shapes are compatible under broadcasting rules.

    Dimensions are aligned from the trailing end and only the overlapping
    dimensions are compared. An aligned pair is compatible when both sizes
    are equal or either of them is 1.

    Parameters:
    - shape1: Tuple[int], the shape of the first array.
    - shape2: Tuple[int], the shape of the second array.

    Returns:
    - bool: True if shapes are broadcastable, False otherwise.
    """
    trailing_pairs = zip(shape1[::-1], shape2[::-1])
    return all(
        left == right or left == 1 or right == 1
        for left, right in trailing_pairs
    )
|
109
|
+
|
110
|
+
|
111
|
+
def _expand_params_to_match_sizes(params, sizes):
    """
    Prepend length-1 axes to ``params`` until its rank equals ``len(sizes)``.

    Parameters:
    - params: array-like exposing ``ndim`` (e.g. jax.Array, np.ndarray), the
      parameter array to be expanded.
    - sizes: tuple[int] or list[int], the target shape dimensions.

    Returns:
    - ``params`` with ``len(sizes) - params.ndim`` leading singleton axes
      added. If ``params`` already has at least ``len(sizes)`` dimensions,
      it is returned unchanged.
    """
    params_dim = params.ndim
    sizes_dim = len(sizes)
    dim_diff = sizes_dim - params_dim

    # New axes are inserted at the FRONT (axis=0), so the existing dimensions
    # stay aligned with the trailing entries of ``sizes``.
    for _ in range(dim_diff):
        params = u.math.expand_dims(params, axis=0)
    return params
|
130
|
+
|
131
|
+
|
132
|
+
@set_module_as('brainstate.nn')
def param(
    parameter: Union[Callable, ArrayLike, State],
    sizes: Union[int, Sequence[int]],
    batch_size: Optional[int] = None,
    allow_none: bool = True,
    allow_scalar: bool = True,
):
    """Initialize parameters.

    Parameters
    ----------
    parameter: callable, ArrayLike, State
        The initialization of the parameter.

        - If it is None, the created parameter will be None (a ``ValueError``
          is raised instead when ``allow_none`` is False).
        - If it is a scalar and ``allow_scalar`` is True, it is returned as-is.
        - If it is a callable function :math:`f` (for example an instance of
          :py:class:`Initializer`), ``f(sizes)`` will be returned. When
          ``batch_size`` is given, ``sizes`` is prefixed with it first.
        - If it is an array (``np.ndarray``, ``jax.Array``, ``u.Quantity``)
          or a ``State``, this function checks that its shape is
          broadcast-compatible with ``sizes``.
    sizes: int, sequence of int
        The shape of the parameter.
    batch_size: int, optional
        The batch size. For array/State parameters whose rank does not exceed
        ``len(sizes)``, leading singleton axes are added up to ``len(sizes)``.
        The value is then repeated ``batch_size`` times along a new leading
        axis. Higher-rank values must already have ``batch_size`` as their
        leading dimension.
    allow_none: bool
        Whether the parameter is allowed to be None.
    allow_scalar: bool
        Whether the parameter is allowed to be a scalar value.

    Returns
    -------
    param: ArrayType, float, int, bool, None
        The initialized parameter. A ``State`` input yields a new instance of
        the same ``State`` type wrapping the (possibly batched) value.

    Raises
    ------
    ValueError
        Raised in any of these cases:

        - ``parameter`` is None and ``allow_none`` is False.
        - Its type is unsupported.
        - Its shape is not broadcast-compatible with ``sizes``.
        - Its leading dimension disagrees with ``batch_size``.

    See Also
    --------
    noise, state
    """
    # Check if the parameter is None
    if parameter is None:
        if allow_none:
            return None
        raise ValueError(f'Expect a parameter with type of float, ArrayType, Initializer, or '
                         f'Callable function, but we got None. ')

    # Scalars are returned unchanged when allowed
    if allow_scalar and _is_scalar(parameter):
        return parameter

    # Convert sizes to a tuple
    sizes = tuple(to_size(sizes))

    # Callables (including Initializer instances) build the value themselves
    if callable(parameter):
        if batch_size is not None:
            sizes = (batch_size,) + sizes
        return parameter(sizes)

    if not isinstance(parameter, (np.ndarray, jax.Array, u.Quantity, State)):
        raise ValueError(f'Unknown parameter type: {type(parameter)}')

    # Unwrap a State before inspecting its shape, so the check works on the
    # underlying array rather than relying on the State wrapper itself.
    param_value = parameter.value if isinstance(parameter, State) else parameter

    # Check if the shape of the parameter matches the given size
    if not are_broadcastable_shapes(param_value.shape, sizes):
        raise ValueError(f'The shape of the parameter {param_value.shape} does not match with the given size {sizes}')

    # Expand the parameter to match the given batch size
    if batch_size is not None:
        if param_value.ndim <= len(sizes):
            # add leading axes so that it matches the dimensionality of ``sizes``
            param_value = _expand_params_to_match_sizes(param_value, sizes)
            param_value = u.math.repeat(
                u.math.expand_dims(param_value, axis=0),
                batch_size,
                axis=0
            )
        else:
            if param_value.shape[0] != batch_size:
                raise ValueError(f'The batch size of the parameter {param_value.shape[0]} '
                                 f'does not match with the given batch size {batch_size}')
    return type(parameter)(param_value) if isinstance(parameter, State) else param_value
|
213
|
+
|
214
|
+
|
215
|
+
def calculate_init_gain(nonlinearity, param=None):
|
45
216
|
r"""Return the recommended gain value for the given nonlinearity function.
|
46
217
|
The values are as follows:
|
47
218
|
|
@@ -123,7 +294,7 @@ class Normal(Initializer):
|
|
123
294
|
The gain of the derivation of the normal distribution.
|
124
295
|
|
125
296
|
"""
|
126
|
-
__module__ = 'brainstate.
|
297
|
+
__module__ = 'brainstate.nn'
|
127
298
|
|
128
299
|
def __init__(
|
129
300
|
self,
|
@@ -138,10 +309,11 @@ class Normal(Initializer):
|
|
138
309
|
self.rng = random.default_rng(seed)
|
139
310
|
self.unit = unit
|
140
311
|
|
141
|
-
def __call__(self, shape,
|
312
|
+
def __call__(self, shape, **kwargs):
|
142
313
|
shape = to_size(shape)
|
143
|
-
dtype = dtype
|
144
|
-
|
314
|
+
dtype = kwargs.get('dtype', environ.dftype())
|
315
|
+
rng = kwargs.get('rng', self.rng)
|
316
|
+
weights = rng.normal(size=shape, loc=self.mean, scale=self.scale, dtype=dtype)
|
145
317
|
return u.maybe_decimal(u.Quantity(weights, unit=self.unit))
|
146
318
|
|
147
319
|
|
@@ -164,7 +336,7 @@ class TruncatedNormal(Initializer):
|
|
164
336
|
truncation. Must be broadcast-compatible with ``lower``.
|
165
337
|
|
166
338
|
"""
|
167
|
-
__module__ = 'brainstate.
|
339
|
+
__module__ = 'brainstate.nn'
|
168
340
|
|
169
341
|
def __init__(
|
170
342
|
self,
|
@@ -184,9 +356,10 @@ class TruncatedNormal(Initializer):
|
|
184
356
|
self.rng = random.default_rng(seed)
|
185
357
|
self.unit = unit
|
186
358
|
|
187
|
-
def __call__(self, shape,
|
188
|
-
dtype = dtype
|
189
|
-
|
359
|
+
def __call__(self, shape, **kwargs):
|
360
|
+
dtype = kwargs.get('dtype', environ.dftype())
|
361
|
+
rng = kwargs.get('rng', self.rng)
|
362
|
+
weights = rng.truncated_normal(
|
190
363
|
size=shape,
|
191
364
|
scale=self.scale,
|
192
365
|
lower=self.lower,
|
@@ -208,7 +381,7 @@ class Gamma(Initializer):
|
|
208
381
|
The gain of the derivation of the Gamma distribution.
|
209
382
|
|
210
383
|
"""
|
211
|
-
__module__ = 'brainstate.
|
384
|
+
__module__ = 'brainstate.nn'
|
212
385
|
|
213
386
|
def __init__(
|
214
387
|
self,
|
@@ -222,10 +395,11 @@ class Gamma(Initializer):
|
|
222
395
|
self.rng = random.default_rng(seed)
|
223
396
|
self.unit = unit
|
224
397
|
|
225
|
-
def __call__(self, shape,
|
398
|
+
def __call__(self, shape, **kwargs):
|
226
399
|
shape = to_size(shape)
|
227
|
-
dtype = dtype
|
228
|
-
|
400
|
+
dtype = kwargs.get('dtype', environ.dftype())
|
401
|
+
rng = kwargs.get('rng', self.rng)
|
402
|
+
weights = rng.gamma(self.shape, scale=self.scale, size=shape, dtype=dtype)
|
229
403
|
return u.maybe_decimal(u.Quantity(weights, unit=self.unit))
|
230
404
|
|
231
405
|
|
@@ -238,7 +412,7 @@ class Exponential(Initializer):
|
|
238
412
|
The gain of the derivation of the Exponential distribution.
|
239
413
|
|
240
414
|
"""
|
241
|
-
__module__ = 'brainstate.
|
415
|
+
__module__ = 'brainstate.nn'
|
242
416
|
|
243
417
|
def __init__(
|
244
418
|
self,
|
@@ -250,10 +424,11 @@ class Exponential(Initializer):
|
|
250
424
|
self.rng = random.default_rng(seed)
|
251
425
|
self.unit = unit
|
252
426
|
|
253
|
-
def __call__(self, shape,
|
427
|
+
def __call__(self, shape, **kwargs):
|
254
428
|
shape = to_size(shape)
|
255
|
-
dtype = dtype
|
256
|
-
|
429
|
+
dtype = kwargs.get('dtype', environ.dftype())
|
430
|
+
rng = kwargs.get('rng', self.rng)
|
431
|
+
weights = rng.exponential(scale=self.scale, size=shape, dtype=dtype)
|
257
432
|
return u.maybe_decimal(u.Quantity(weights, unit=self.unit))
|
258
433
|
|
259
434
|
|
@@ -267,7 +442,7 @@ class Uniform(Initializer):
|
|
267
442
|
max_val : float
|
268
443
|
The upper limit of the uniform distribution.
|
269
444
|
"""
|
270
|
-
__module__ = 'brainstate.
|
445
|
+
__module__ = 'brainstate.nn'
|
271
446
|
|
272
447
|
def __init__(
|
273
448
|
self,
|
@@ -282,15 +457,16 @@ class Uniform(Initializer):
|
|
282
457
|
self.rng = random.default_rng(seed)
|
283
458
|
self.unit = unit
|
284
459
|
|
285
|
-
def __call__(self, shape,
|
460
|
+
def __call__(self, shape, **kwargs):
|
286
461
|
shape = to_size(shape)
|
287
|
-
dtype = dtype
|
288
|
-
|
462
|
+
dtype = kwargs.get('dtype', environ.dftype())
|
463
|
+
rng = kwargs.get('rng', self.rng)
|
464
|
+
weights = rng.uniform(low=self.min_val, high=self.max_val, size=shape, dtype=dtype)
|
289
465
|
return u.maybe_decimal(u.Quantity(weights, unit=self.unit))
|
290
466
|
|
291
467
|
|
292
468
|
class VarianceScaling(Initializer):
|
293
|
-
__module__ = 'brainstate.
|
469
|
+
__module__ = 'brainstate.nn'
|
294
470
|
|
295
471
|
def __init__(
|
296
472
|
self,
|
@@ -312,9 +488,10 @@ class VarianceScaling(Initializer):
|
|
312
488
|
self.rng = random.default_rng(seed)
|
313
489
|
self.unit = unit
|
314
490
|
|
315
|
-
def __call__(self, shape,
|
491
|
+
def __call__(self, shape, **kwargs):
|
316
492
|
shape = to_size(shape)
|
317
|
-
dtype = dtype
|
493
|
+
dtype = kwargs.get('dtype', environ.dftype())
|
494
|
+
rng = kwargs.get('rng', self.rng)
|
318
495
|
fan_in, fan_out = _compute_fans(shape, in_axis=self.in_axis, out_axis=self.out_axis)
|
319
496
|
if self.mode == "fan_in":
|
320
497
|
denominator = fan_in
|
@@ -327,11 +504,11 @@ class VarianceScaling(Initializer):
|
|
327
504
|
variance = (self.scale / denominator).astype(dtype)
|
328
505
|
if self.distribution == "truncated_normal":
|
329
506
|
stddev = (jnp.sqrt(variance) / .87962566103423978).astype(dtype)
|
330
|
-
res =
|
507
|
+
res = rng.truncated_normal(-2, 2, shape, dtype=dtype) * stddev
|
331
508
|
elif self.distribution == "normal":
|
332
|
-
res =
|
509
|
+
res = rng.randn(*shape, dtype=dtype) * jnp.sqrt(variance).astype(dtype)
|
333
510
|
elif self.distribution == "uniform":
|
334
|
-
res = (
|
511
|
+
res = (rng.uniform(low=-1, high=1, size=shape, dtype=dtype) *
|
335
512
|
jnp.sqrt(3 * variance).astype(dtype))
|
336
513
|
else:
|
337
514
|
raise ValueError("invalid distribution for variance scaling initializer")
|
@@ -339,7 +516,7 @@ class VarianceScaling(Initializer):
|
|
339
516
|
|
340
517
|
|
341
518
|
class KaimingUniform(VarianceScaling):
|
342
|
-
__module__ = 'brainstate.
|
519
|
+
__module__ = 'brainstate.nn'
|
343
520
|
|
344
521
|
def __init__(
|
345
522
|
self,
|
@@ -361,7 +538,7 @@ class KaimingUniform(VarianceScaling):
|
|
361
538
|
|
362
539
|
|
363
540
|
class KaimingNormal(VarianceScaling):
|
364
|
-
__module__ = 'brainstate.
|
541
|
+
__module__ = 'brainstate.nn'
|
365
542
|
|
366
543
|
def __init__(
|
367
544
|
self,
|
@@ -383,7 +560,7 @@ class KaimingNormal(VarianceScaling):
|
|
383
560
|
|
384
561
|
|
385
562
|
class XavierUniform(VarianceScaling):
|
386
|
-
__module__ = 'brainstate.
|
563
|
+
__module__ = 'brainstate.nn'
|
387
564
|
|
388
565
|
def __init__(
|
389
566
|
self,
|
@@ -405,7 +582,7 @@ class XavierUniform(VarianceScaling):
|
|
405
582
|
|
406
583
|
|
407
584
|
class XavierNormal(VarianceScaling):
|
408
|
-
__module__ = 'brainstate.
|
585
|
+
__module__ = 'brainstate.nn'
|
409
586
|
|
410
587
|
def __init__(
|
411
588
|
self,
|
@@ -427,7 +604,7 @@ class XavierNormal(VarianceScaling):
|
|
427
604
|
|
428
605
|
|
429
606
|
class LecunUniform(VarianceScaling):
|
430
|
-
__module__ = 'brainstate.
|
607
|
+
__module__ = 'brainstate.nn'
|
431
608
|
|
432
609
|
def __init__(
|
433
610
|
self,
|
@@ -449,7 +626,7 @@ class LecunUniform(VarianceScaling):
|
|
449
626
|
|
450
627
|
|
451
628
|
class LecunNormal(VarianceScaling):
|
452
|
-
__module__ = 'brainstate.
|
629
|
+
__module__ = 'brainstate.nn'
|
453
630
|
|
454
631
|
def __init__(
|
455
632
|
self,
|
@@ -477,7 +654,7 @@ class Orthogonal(Initializer):
|
|
477
654
|
If the shape is not square, the matrix will have orthonormal rows or columns
|
478
655
|
depending on which side is smaller.
|
479
656
|
"""
|
480
|
-
__module__ = 'brainstate.
|
657
|
+
__module__ = 'brainstate.nn'
|
481
658
|
|
482
659
|
def __init__(
|
483
660
|
self,
|
@@ -492,13 +669,14 @@ class Orthogonal(Initializer):
|
|
492
669
|
self.rng = random.default_rng(seed)
|
493
670
|
self.unit = unit
|
494
671
|
|
495
|
-
def __call__(self, shape,
|
496
|
-
dtype = dtype
|
672
|
+
def __call__(self, shape, **kwargs):
|
673
|
+
dtype = kwargs.get('dtype', environ.dftype())
|
674
|
+
rng = kwargs.get('rng', self.rng)
|
497
675
|
shape = to_size(shape)
|
498
676
|
n_rows = shape[self.axis]
|
499
677
|
n_cols = np.prod(shape) // n_rows
|
500
678
|
matrix_shape = (n_rows, n_cols) if n_rows > n_cols else (n_cols, n_rows)
|
501
|
-
norm_dst =
|
679
|
+
norm_dst = rng.normal(size=matrix_shape, dtype=dtype)
|
502
680
|
|
503
681
|
q_mat, r_mat = jnp.linalg.qr(norm_dst)
|
504
682
|
# Enforce Q is uniformly distributed
|
@@ -517,7 +695,7 @@ class DeltaOrthogonal(Initializer):
|
|
517
695
|
|
518
696
|
The shape must be 3D, 4D or 5D.
|
519
697
|
"""
|
520
|
-
__module__ = 'brainstate.
|
698
|
+
__module__ = 'brainstate.nn'
|
521
699
|
|
522
700
|
def __init__(
|
523
701
|
self,
|
@@ -532,9 +710,9 @@ class DeltaOrthogonal(Initializer):
|
|
532
710
|
self.orghogonal = Orthogonal(scale=scale, axis=axis, seed=seed)
|
533
711
|
self.unit = unit
|
534
712
|
|
535
|
-
def __call__(self, shape,
|
713
|
+
def __call__(self, shape, **kwargs):
|
536
714
|
shape = to_size(shape)
|
537
|
-
dtype = dtype
|
715
|
+
dtype = kwargs.get('dtype', environ.dftype())
|
538
716
|
if len(shape) not in [3, 4, 5]:
|
539
717
|
raise ValueError("Delta orthogonal initializer requires a 3D, 4D or 5D shape.")
|
540
718
|
if shape[-1] < shape[-2]:
|
@@ -551,3 +729,81 @@ class DeltaOrthogonal(Initializer):
|
|
551
729
|
k1, k2, k3 = shape[:3]
|
552
730
|
W = W.at[(k1 - 1) // 2, (k2 - 1) // 2, (k3 - 1) // 2].set(ortho_matrix)
|
553
731
|
return u.maybe_decimal(u.Quantity(W.mantissa, unit=self.unit))
|
732
|
+
|
733
|
+
|
734
|
+
class ZeroInit(Initializer):
    """Initializer that produces an all-zero array.

    Parameters
    ----------
    unit : u.Unit
        Unit attached to the generated zeros; unitless by default.
    """
    __module__ = 'brainstate.nn'

    def __init__(self, unit: u.Unit = u.UNITLESS):
        super().__init__()
        self.unit = unit

    def __call__(self, shape, **kwargs):
        # An explicit ``dtype`` keyword wins over the environment default.
        target_dtype = kwargs.get('dtype', environ.dftype())
        target_shape = to_size(shape)
        zeros = u.math.zeros(target_shape, dtype=target_dtype, unit=self.unit)
        return u.maybe_decimal(zeros)
|
749
|
+
|
750
|
+
|
751
|
+
class Constant(Initializer):
    """Constant initializer.

    Initialize the weights with the given value.

    Parameters
    ----------
    value : float, int, or array-like
        The fill value, passed to ``brainunit.math.full``. Defaults to ``1.``.
    """
    __module__ = 'brainstate.nn'

    def __init__(self, value=1.):
        super(Constant, self).__init__()
        self.value = value

    def __call__(self, shape, **kwargs):
        """Return an array of ``shape`` filled with ``self.value``.

        A ``dtype`` keyword may be given; it defaults to ``environ.dftype()``.
        """
        dtype = kwargs.get('dtype', environ.dftype())
        shape = to_size(shape)
        return u.maybe_decimal(u.math.full(shape, self.value, dtype=dtype))
|
771
|
+
|
772
|
+
|
773
|
+
class Identity(Initializer):
    """Returns the (scaled) identity matrix.

    This initializer was proposed in (Le, et al., 2015) [1]_.

    Parameters
    ----------
    value : float
        The optional scaling factor, written onto the main diagonal.
    unit : u.Unit
        Unit attached to the returned matrix; unitless by default.

    References
    ----------
    .. [1] Le, Quoc V., Navdeep Jaitly, and Geoffrey E. Hinton. "A simple way to
           initialize recurrent networks of rectified linear units." arXiv preprint
           arXiv:1504.00941 (2015).
    """
    __module__ = 'brainstate.nn'

    def __init__(self, value=1., unit: u.Unit = u.UNITLESS):
        super(Identity, self).__init__()
        self.value = value
        self.unit = unit

    def __call__(self, shape, **kwargs):
        """Build the matrix for ``shape``.

        A 1-element shape ``(n,)`` is passed to ``eye(n)``; a 2-element shape
        ``(n, m)`` is passed to ``eye(n, m)``. The diagonal is then filled with
        ``self.value``.

        Raises
        ------
        ValueError
            If ``shape`` is empty/None or has more than two dimensions.
        """
        dtype = kwargs.get('dtype', environ.dftype())
        shape = to_size(shape)
        # ``to_size`` may return None (or an empty tuple); previously this fell
        # through and crashed later with an unrelated UnboundLocalError/TypeError.
        if shape is None or len(shape) == 0:
            raise ValueError(f'{self.__class__.__name__} requires a 1D or 2D shape, but got {shape}.')
        if len(shape) > 2:
            raise ValueError(f'Only support initialize 2D weights for {self.__class__.__name__}.')
        r = u.math.eye(*shape, dtype=dtype)
        r = u.math.fill_diagonal(r, self.value)
        return u.maybe_decimal(u.Quantity(r, unit=self.unit))
|