brainstate 0.0.2.post20241010__py2.py3-none-any.whl → 0.1.0__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +31 -11
- brainstate/_state.py +760 -316
- brainstate/_state_test.py +41 -12
- brainstate/_utils.py +31 -4
- brainstate/augment/__init__.py +40 -0
- brainstate/augment/_autograd.py +608 -0
- brainstate/augment/_autograd_test.py +1193 -0
- brainstate/augment/_eval_shape.py +102 -0
- brainstate/augment/_eval_shape_test.py +40 -0
- brainstate/augment/_mapping.py +525 -0
- brainstate/augment/_mapping_test.py +210 -0
- brainstate/augment/_random.py +99 -0
- brainstate/{transform → compile}/__init__.py +25 -13
- brainstate/compile/_ad_checkpoint.py +204 -0
- brainstate/compile/_ad_checkpoint_test.py +51 -0
- brainstate/compile/_conditions.py +259 -0
- brainstate/compile/_conditions_test.py +221 -0
- brainstate/compile/_error_if.py +94 -0
- brainstate/compile/_error_if_test.py +54 -0
- brainstate/compile/_jit.py +314 -0
- brainstate/compile/_jit_test.py +143 -0
- brainstate/compile/_loop_collect_return.py +516 -0
- brainstate/compile/_loop_collect_return_test.py +59 -0
- brainstate/compile/_loop_no_collection.py +185 -0
- brainstate/compile/_loop_no_collection_test.py +51 -0
- brainstate/compile/_make_jaxpr.py +756 -0
- brainstate/compile/_make_jaxpr_test.py +134 -0
- brainstate/compile/_progress_bar.py +111 -0
- brainstate/compile/_unvmap.py +159 -0
- brainstate/compile/_util.py +147 -0
- brainstate/environ.py +408 -381
- brainstate/environ_test.py +34 -32
- brainstate/{nn/event → event}/__init__.py +6 -6
- brainstate/event/_csr.py +308 -0
- brainstate/event/_csr_test.py +118 -0
- brainstate/event/_fixed_probability.py +271 -0
- brainstate/event/_fixed_probability_test.py +128 -0
- brainstate/event/_linear.py +219 -0
- brainstate/event/_linear_test.py +112 -0
- brainstate/{nn/event → event}/_misc.py +7 -7
- brainstate/functional/_activations.py +521 -511
- brainstate/functional/_activations_test.py +300 -300
- brainstate/functional/_normalization.py +43 -43
- brainstate/functional/_others.py +15 -15
- brainstate/functional/_spikes.py +49 -49
- brainstate/graph/__init__.py +33 -0
- brainstate/graph/_graph_context.py +443 -0
- brainstate/graph/_graph_context_test.py +65 -0
- brainstate/graph/_graph_convert.py +246 -0
- brainstate/graph/_graph_node.py +300 -0
- brainstate/graph/_graph_node_test.py +75 -0
- brainstate/graph/_graph_operation.py +1746 -0
- brainstate/graph/_graph_operation_test.py +724 -0
- brainstate/init/_base.py +28 -10
- brainstate/init/_generic.py +175 -172
- brainstate/init/_random_inits.py +470 -415
- brainstate/init/_random_inits_test.py +150 -0
- brainstate/init/_regular_inits.py +66 -69
- brainstate/init/_regular_inits_test.py +51 -0
- brainstate/mixin.py +236 -244
- brainstate/mixin_test.py +44 -46
- brainstate/nn/__init__.py +26 -51
- brainstate/nn/_collective_ops.py +199 -0
- brainstate/nn/_dyn_impl/__init__.py +46 -0
- brainstate/nn/_dyn_impl/_dynamics_neuron.py +290 -0
- brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +162 -0
- brainstate/nn/_dyn_impl/_dynamics_synapse.py +320 -0
- brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +132 -0
- brainstate/nn/_dyn_impl/_inputs.py +154 -0
- brainstate/nn/{_projection/__init__.py → _dyn_impl/_projection_alignpost.py} +6 -13
- brainstate/nn/_dyn_impl/_rate_rnns.py +400 -0
- brainstate/nn/_dyn_impl/_rate_rnns_test.py +64 -0
- brainstate/nn/_dyn_impl/_readout.py +128 -0
- brainstate/nn/_dyn_impl/_readout_test.py +54 -0
- brainstate/nn/_dynamics/__init__.py +37 -0
- brainstate/nn/_dynamics/_dynamics_base.py +631 -0
- brainstate/nn/_dynamics/_dynamics_base_test.py +79 -0
- brainstate/nn/_dynamics/_projection_base.py +346 -0
- brainstate/nn/_dynamics/_state_delay.py +453 -0
- brainstate/nn/_dynamics/_synouts.py +161 -0
- brainstate/nn/_dynamics/_synouts_test.py +58 -0
- brainstate/nn/_elementwise/__init__.py +22 -0
- brainstate/nn/_elementwise/_dropout.py +418 -0
- brainstate/nn/_elementwise/_dropout_test.py +100 -0
- brainstate/nn/_elementwise/_elementwise.py +1122 -0
- brainstate/nn/_elementwise/_elementwise_test.py +171 -0
- brainstate/nn/_exp_euler.py +97 -0
- brainstate/nn/_exp_euler_test.py +36 -0
- brainstate/nn/_interaction/__init__.py +32 -0
- brainstate/nn/_interaction/_connections.py +726 -0
- brainstate/nn/_interaction/_connections_test.py +254 -0
- brainstate/nn/_interaction/_embedding.py +59 -0
- brainstate/nn/_interaction/_normalizations.py +388 -0
- brainstate/nn/_interaction/_normalizations_test.py +75 -0
- brainstate/nn/_interaction/_poolings.py +1179 -0
- brainstate/nn/_interaction/_poolings_test.py +219 -0
- brainstate/nn/_module.py +328 -0
- brainstate/nn/_module_test.py +211 -0
- brainstate/nn/metrics.py +309 -309
- brainstate/optim/__init__.py +14 -2
- brainstate/optim/_base.py +66 -0
- brainstate/optim/_lr_scheduler.py +363 -400
- brainstate/optim/_lr_scheduler_test.py +25 -24
- brainstate/optim/_optax_optimizer.py +103 -176
- brainstate/optim/_optax_optimizer_test.py +41 -1
- brainstate/optim/_sgd_optimizer.py +950 -1025
- brainstate/random/_rand_funs.py +3269 -3268
- brainstate/random/_rand_funs_test.py +568 -0
- brainstate/random/_rand_seed.py +149 -117
- brainstate/random/_rand_seed_test.py +50 -0
- brainstate/random/_rand_state.py +1356 -1321
- brainstate/random/_random_for_unit.py +13 -13
- brainstate/surrogate.py +1262 -1243
- brainstate/{nn/_projection/_utils.py → transform.py} +1 -2
- brainstate/typing.py +157 -130
- brainstate/util/__init__.py +52 -0
- brainstate/util/_caller.py +100 -0
- brainstate/util/_dict.py +734 -0
- brainstate/util/_dict_test.py +160 -0
- brainstate/util/_error.py +28 -0
- brainstate/util/_filter.py +178 -0
- brainstate/util/_others.py +497 -0
- brainstate/util/_pretty_repr.py +208 -0
- brainstate/util/_scaling.py +260 -0
- brainstate/util/_struct.py +524 -0
- brainstate/util/_tracers.py +75 -0
- brainstate/{_visualization.py → util/_visualization.py} +16 -16
- {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.dist-info}/METADATA +11 -11
- brainstate-0.1.0.dist-info/RECORD +135 -0
- brainstate/_module.py +0 -1637
- brainstate/_module_test.py +0 -207
- brainstate/nn/_base.py +0 -251
- brainstate/nn/_connections.py +0 -686
- brainstate/nn/_dynamics.py +0 -426
- brainstate/nn/_elementwise.py +0 -1438
- brainstate/nn/_embedding.py +0 -66
- brainstate/nn/_misc.py +0 -133
- brainstate/nn/_normalizations.py +0 -389
- brainstate/nn/_others.py +0 -101
- brainstate/nn/_poolings.py +0 -1229
- brainstate/nn/_poolings_test.py +0 -231
- brainstate/nn/_projection/_align_post.py +0 -546
- brainstate/nn/_projection/_align_pre.py +0 -599
- brainstate/nn/_projection/_delta.py +0 -241
- brainstate/nn/_projection/_vanilla.py +0 -101
- brainstate/nn/_rate_rnns.py +0 -410
- brainstate/nn/_readout.py +0 -136
- brainstate/nn/_synouts.py +0 -166
- brainstate/nn/event/csr.py +0 -312
- brainstate/nn/event/csr_test.py +0 -118
- brainstate/nn/event/fixed_probability.py +0 -276
- brainstate/nn/event/fixed_probability_test.py +0 -127
- brainstate/nn/event/linear.py +0 -220
- brainstate/nn/event/linear_test.py +0 -111
- brainstate/random/random_test.py +0 -593
- brainstate/transform/_autograd.py +0 -585
- brainstate/transform/_autograd_test.py +0 -1181
- brainstate/transform/_conditions.py +0 -334
- brainstate/transform/_conditions_test.py +0 -220
- brainstate/transform/_error_if.py +0 -94
- brainstate/transform/_error_if_test.py +0 -55
- brainstate/transform/_jit.py +0 -265
- brainstate/transform/_jit_test.py +0 -118
- brainstate/transform/_loop_collect_return.py +0 -502
- brainstate/transform/_loop_no_collection.py +0 -170
- brainstate/transform/_make_jaxpr.py +0 -739
- brainstate/transform/_make_jaxpr_test.py +0 -131
- brainstate/transform/_mapping.py +0 -109
- brainstate/transform/_progress_bar.py +0 -111
- brainstate/transform/_unvmap.py +0 -143
- brainstate/util.py +0 -746
- brainstate-0.0.2.post20241010.dist-info/RECORD +0 -87
- {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.dist-info}/LICENSE +0 -0
- {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.dist-info}/WHEEL +0 -0
- {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.dist-info}/top_level.txt +0 -0
brainstate/init/_base.py
CHANGED
@@ -13,24 +13,42 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
15
|
|
16
|
+
from __future__ import annotations
|
16
17
|
|
17
18
|
from typing import Optional, Tuple
|
18
19
|
|
19
20
|
import numpy as np
|
20
21
|
|
22
|
+
from brainstate.util import PrettyRepr, PrettyType, PrettyAttr
|
23
|
+
|
21
24
|
__all__ = ['Initializer', 'to_size']
|
22
25
|
|
23
26
|
|
24
|
-
class Initializer(PrettyRepr):
    """
    Common base for every initializer exposed under ``brainstate.init``.

    Concrete initializers override ``__call__``; callable initializers are
    invoked with the requested shape by ``brainstate.init.param`` and
    ``brainstate.init.state``. The pretty representation shows the concrete
    type followed by each public (non-underscore) instance attribute.
    """
    __module__ = 'brainstate.init'

    def __call__(self, *args, **kwargs):
        raise NotImplementedError

    def __pretty_repr__(self):
        """
        Generate the pieces of the pretty representation.

        The first item describes the concrete type; every following item is one
        public instance attribute rendered with ``repr``.
        """
        yield PrettyType(type=type(self))
        visible = ((attr, val) for attr, val in vars(self).items() if not attr.startswith('_'))
        for attr, val in visible:
            yield PrettyAttr(attr, repr(val))
|
27
45
|
|
28
46
|
|
29
47
|
def to_size(x) -> Optional[Tuple[int, ...]]:
    """
    Normalize a size specification into a shape tuple.

    Parameters
    ----------
    x : int, np.integer, tuple, list, or None
        The size to normalize.

    Returns
    -------
    tuple or None
        ``tuple(x)`` for a tuple or list, ``(x,)`` for a single integer, and
        ``None`` when ``x`` is ``None``.

    Raises
    ------
    ValueError
        If ``x`` is of any other type.
    """
    if isinstance(x, (tuple, list)):
        return tuple(x)
    if isinstance(x, (int, np.integer)):
        return (x,)
    if x is None:
        return None
    raise ValueError(f'Cannot make a size for {x}')
|
brainstate/init/_generic.py
CHANGED
@@ -14,6 +14,7 @@
|
|
14
14
|
# ==============================================================================
|
15
15
|
|
16
16
|
# -*- coding: utf-8 -*-
|
17
|
+
from __future__ import annotations
|
17
18
|
|
18
19
|
from typing import Union, Callable, Optional, Sequence
|
19
20
|
|
@@ -22,221 +23,223 @@ import jax
|
|
22
23
|
import numpy as np
|
23
24
|
|
24
25
|
from brainstate._state import State
|
26
|
+
from brainstate._utils import set_module_as
|
25
27
|
from brainstate.typing import ArrayLike
|
26
28
|
from ._base import to_size
|
27
|
-
from brainstate.mixin import Mode
|
28
29
|
|
29
30
|
__all__ = [
|
30
|
-
|
31
|
-
|
32
|
-
|
31
|
+
'param',
|
32
|
+
'state',
|
33
|
+
'noise',
|
33
34
|
]
|
34
35
|
|
35
36
|
|
36
37
|
def _is_scalar(x):
    """Report whether ``x`` is a scalar, delegating to ``bu.math.isscalar``."""
    scalar_flag = bu.math.isscalar(x)
    return scalar_flag
|
38
39
|
|
39
40
|
|
40
41
|
def are_broadcastable_shapes(shape1, shape2):
    """
    Tell whether two shapes are compatible under NumPy-style broadcasting.

    Dimensions are paired up starting from the trailing end; each pair is
    compatible when the two sizes are equal or when either of them is 1.
    Leading dimensions of the longer shape that have no partner are not
    inspected.

    Parameters:
    - shape1: Tuple[int], the first shape.
    - shape2: Tuple[int], the second shape.

    Returns:
    - bool: True when every paired dimension is compatible, False otherwise.
    """
    trailing_pairs = zip(shape1[::-1], shape2[::-1])
    return all(d1 == d2 or 1 in (d1, d2) for d1, d2 in trailing_pairs)
|
63
64
|
|
64
65
|
|
65
66
|
def _expand_params_to_match_sizes(params, sizes):
    """
    Prepend size-1 axes to ``params`` until its rank equals ``len(sizes)``.

    Parameters:
    - params: array exposing ``.ndim`` (e.g. jax.Array or np.ndarray), the
      parameter array to be expanded.
    - sizes: tuple[int] or list[int], the target shape; only its length is used.

    Returns:
    - ``params`` with ``len(sizes) - params.ndim`` new size-1 axes inserted at
      the front. If ``params`` already has at least ``len(sizes)`` dimensions,
      it is returned unchanged.
    """
    dim_diff = len(sizes) - params.ndim

    # New axes go at the FRONT (axis=0), not the end: broadcasting aligns shapes
    # from the trailing dimension, so the existing dims of ``params`` stay last.
    for _ in range(dim_diff):
        params = bu.math.expand_dims(params, axis=0)
    return params
|
84
85
|
|
85
86
|
|
87
|
+
@set_module_as('brainstate.init')
def param(
    parameter: Union[Callable, ArrayLike, State],
    sizes: Union[int, Sequence[int]],
    batch_size: Optional[int] = None,
    allow_none: bool = True,
    allow_scalar: bool = True,
):
    """Initialize parameters.

    Parameters
    ----------
    parameter: callable, ArrayLike, State
        The initialization of the parameter.

        - If it is None, the created parameter will be None (when ``allow_none``
          is True); otherwise a ``ValueError`` is raised.
        - If it is a scalar and ``allow_scalar`` is True, it is returned as is.
        - If it is a callable function :math:`f` (including instances of
          :py:class:`init.Initializer`), ``f(size)`` is returned, where ``size``
          has ``batch_size`` prepended when it is given.
        - If it is an array or a :py:class:`State`, this function checks that
          its shape is broadcastable to ``sizes`` and, when ``batch_size`` is
          given, adds (or validates) the leading batch dimension.
    sizes: int, sequence of int
        The shape of the parameter.
    batch_size: int
        The batch size.
    allow_none: bool
        Whether allow the parameter is None.
    allow_scalar: bool
        Whether allow the parameter is a scalar value.

    Returns
    -------
    param: ArrayType, float, int, bool, None
        The initialized parameter. A :py:class:`State` input yields a new
        instance of the same State type wrapping the (possibly batched) value.

    Raises
    ------
    ValueError
        If ``parameter`` is None while ``allow_none`` is False, has an
        unsupported type, has a shape not broadcastable to ``sizes``, or has a
        leading dimension that differs from ``batch_size``.

    See Also
    --------
    noise, state
    """
    # Check if the parameter is None
    if parameter is None:
        if allow_none:
            return None
        raise ValueError(f'Expect a parameter with type of float, ArrayType, Initializer, or '
                         f'Callable function, but we got None. ')

    # Check if the parameter is a scalar value
    if allow_scalar and _is_scalar(parameter):
        return parameter

    # Convert sizes to a tuple
    sizes = tuple(to_size(sizes))

    # Callables (e.g. Initializer instances) build the value for the full shape themselves.
    if callable(parameter):
        if batch_size is not None:
            sizes = (batch_size,) + sizes
        return parameter(sizes)

    if not isinstance(parameter, (np.ndarray, jax.Array, bu.Quantity, State)):
        raise ValueError(f'Unknown parameter type: {type(parameter)}')

    # Unwrap a State *before* inspecting the shape so the check runs on the
    # underlying data instead of relying on the wrapper exposing ``.shape``.
    param_value = parameter.value if isinstance(parameter, State) else parameter

    # Check if the shape of the parameter matches the given size
    if not are_broadcastable_shapes(param_value.shape, sizes):
        raise ValueError(f'The shape of the parameter {param_value.shape} does not match with the given size {sizes}')

    # Expand the parameter to match the given batch size
    if batch_size is not None:
        if param_value.ndim <= len(sizes):
            # No batch axis yet: lift the value to the rank of ``sizes``, then add
            # a leading axis and repeat it ``batch_size`` times along that axis.
            param_value = _expand_params_to_match_sizes(param_value, sizes)
            param_value = bu.math.repeat(
                bu.math.expand_dims(param_value, axis=0),
                batch_size,
                axis=0
            )
        elif param_value.shape[0] != batch_size:
            # The value already carries an extra leading axis; it must match the batch size.
            raise ValueError(f'The batch size of the parameter {param_value.shape[0]} '
                             f'does not match with the given batch size {batch_size}')
    return type(parameter)(param_value) if isinstance(parameter, State) else param_value
|
168
|
+
|
169
|
+
|
170
|
+
@set_module_as('brainstate.init')
def state(
    init: Union[Callable, jax.typing.ArrayLike],
    sizes: Union[int, Sequence[int]] = None,
    batch_size: Optional[int] = None,
):
    """
    Initialize a :class:`~.State` from a callable function or a data.

    Parameters
    ----------
    init: callable, ArrayLike
        Either a callable invoked as ``init(shape)``, or the initial data itself.
    sizes: int, sequence of int, optional
        The shape of the state, without the batch dimension. Required when
        ``init`` is callable. When ``init`` is data and ``sizes`` is given, the
        data's shape must equal ``sizes`` exactly.
    batch_size: int, optional
        If given, a leading batch dimension of this size is added. For a
        callable ``init`` it is prepended to the requested shape. For data, the
        data is repeated ``batch_size`` times along a new leading axis.

    Returns
    -------
    State
        A new :class:`~.State` holding the initialized value.

    Raises
    ------
    ValueError
        If ``init`` is callable and ``sizes`` is None, or if ``init`` is data
        whose shape differs from ``sizes``.
    """
    sizes = to_size(sizes)
    # Accept NumPy integers as well as Python ints, mirroring ``to_size``;
    # previously a NumPy-integer batch size was silently ignored.
    has_batch = isinstance(batch_size, (int, np.integer))

    if callable(init):
        if sizes is None:
            raise ValueError('"sizes" cannot be None when data is a callable function.')
        # Pass a tuple shape, consistent with how ``param`` calls initializers.
        shape = (batch_size,) + sizes if has_batch else sizes
        return State(init(shape))

    if sizes is not None and bu.math.shape(init) != sizes:
        raise ValueError(f'The shape of "data" {bu.math.shape(init)} does not match with "sizes" {sizes}')
    if has_batch:
        # Add a leading axis and tile the data ``batch_size`` times along it.
        return State(bu.math.repeat(bu.math.expand_dims(init, axis=0), batch_size, axis=0))
    return State(init)
|
204
|
+
|
205
|
+
|
206
|
+
@set_module_as('brainstate.init')
def noise(
    noises: Optional[Union[ArrayLike, Callable]],
    size: Union[int, Sequence[int]],
    num_vars: int = 1,
    noise_idx: int = 0,
) -> Optional[Callable]:
    """Initialize a noise function.

    Parameters
    ----------
    noises: ArrayLike, callable, None
        The noise specification. A callable is returned unchanged and ``None``
        yields ``None``. Any other value is initialized with :func:`param`
        against ``size`` (``None`` not allowed there) and wrapped into a
        constant function.
    size: Shape
        The size of the noise.
    num_vars: int
        The number of variables.
    noise_idx: int
        The index of the current noise among all noise variables.

    Returns
    -------
    noise_func: function, None
        The noise function. For non-callable ``noises`` it ignores its
        arguments and returns the initialized value. When ``num_vars > 1`` it
        instead returns a tuple of length ``num_vars`` holding that value at
        position ``noise_idx`` and ``None`` everywhere else.

    See Also
    --------
    param, state
    """
    if callable(noises):
        return noises
    if noises is None:
        return None

    noises = param(noises, size, allow_none=False)
    if num_vars > 1:
        # Only the ``noise_idx``-th variable receives noise; the others get None.
        per_var = [None] * num_vars
        per_var[noise_idx] = noises
        noises = tuple(per_var)
    return lambda *args, **kwargs: noises
|