brainstate 0.0.2.post20241010__py2.py3-none-any.whl → 0.1.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (175)
  1. brainstate/__init__.py +31 -11
  2. brainstate/_state.py +760 -316
  3. brainstate/_state_test.py +41 -12
  4. brainstate/_utils.py +31 -4
  5. brainstate/augment/__init__.py +40 -0
  6. brainstate/augment/_autograd.py +608 -0
  7. brainstate/augment/_autograd_test.py +1193 -0
  8. brainstate/augment/_eval_shape.py +102 -0
  9. brainstate/augment/_eval_shape_test.py +40 -0
  10. brainstate/augment/_mapping.py +525 -0
  11. brainstate/augment/_mapping_test.py +210 -0
  12. brainstate/augment/_random.py +99 -0
  13. brainstate/{transform → compile}/__init__.py +25 -13
  14. brainstate/compile/_ad_checkpoint.py +204 -0
  15. brainstate/compile/_ad_checkpoint_test.py +51 -0
  16. brainstate/compile/_conditions.py +259 -0
  17. brainstate/compile/_conditions_test.py +221 -0
  18. brainstate/compile/_error_if.py +94 -0
  19. brainstate/compile/_error_if_test.py +54 -0
  20. brainstate/compile/_jit.py +314 -0
  21. brainstate/compile/_jit_test.py +143 -0
  22. brainstate/compile/_loop_collect_return.py +516 -0
  23. brainstate/compile/_loop_collect_return_test.py +59 -0
  24. brainstate/compile/_loop_no_collection.py +185 -0
  25. brainstate/compile/_loop_no_collection_test.py +51 -0
  26. brainstate/compile/_make_jaxpr.py +756 -0
  27. brainstate/compile/_make_jaxpr_test.py +134 -0
  28. brainstate/compile/_progress_bar.py +111 -0
  29. brainstate/compile/_unvmap.py +159 -0
  30. brainstate/compile/_util.py +147 -0
  31. brainstate/environ.py +408 -381
  32. brainstate/environ_test.py +34 -32
  33. brainstate/{nn/event → event}/__init__.py +6 -6
  34. brainstate/event/_csr.py +308 -0
  35. brainstate/event/_csr_test.py +118 -0
  36. brainstate/event/_fixed_probability.py +271 -0
  37. brainstate/event/_fixed_probability_test.py +128 -0
  38. brainstate/event/_linear.py +219 -0
  39. brainstate/event/_linear_test.py +112 -0
  40. brainstate/{nn/event → event}/_misc.py +7 -7
  41. brainstate/functional/_activations.py +521 -511
  42. brainstate/functional/_activations_test.py +300 -300
  43. brainstate/functional/_normalization.py +43 -43
  44. brainstate/functional/_others.py +15 -15
  45. brainstate/functional/_spikes.py +49 -49
  46. brainstate/graph/__init__.py +33 -0
  47. brainstate/graph/_graph_context.py +443 -0
  48. brainstate/graph/_graph_context_test.py +65 -0
  49. brainstate/graph/_graph_convert.py +246 -0
  50. brainstate/graph/_graph_node.py +300 -0
  51. brainstate/graph/_graph_node_test.py +75 -0
  52. brainstate/graph/_graph_operation.py +1746 -0
  53. brainstate/graph/_graph_operation_test.py +724 -0
  54. brainstate/init/_base.py +28 -10
  55. brainstate/init/_generic.py +175 -172
  56. brainstate/init/_random_inits.py +470 -415
  57. brainstate/init/_random_inits_test.py +150 -0
  58. brainstate/init/_regular_inits.py +66 -69
  59. brainstate/init/_regular_inits_test.py +51 -0
  60. brainstate/mixin.py +236 -244
  61. brainstate/mixin_test.py +44 -46
  62. brainstate/nn/__init__.py +26 -51
  63. brainstate/nn/_collective_ops.py +199 -0
  64. brainstate/nn/_dyn_impl/__init__.py +46 -0
  65. brainstate/nn/_dyn_impl/_dynamics_neuron.py +290 -0
  66. brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +162 -0
  67. brainstate/nn/_dyn_impl/_dynamics_synapse.py +320 -0
  68. brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +132 -0
  69. brainstate/nn/_dyn_impl/_inputs.py +154 -0
  70. brainstate/nn/{_projection/__init__.py → _dyn_impl/_projection_alignpost.py} +6 -13
  71. brainstate/nn/_dyn_impl/_rate_rnns.py +400 -0
  72. brainstate/nn/_dyn_impl/_rate_rnns_test.py +64 -0
  73. brainstate/nn/_dyn_impl/_readout.py +128 -0
  74. brainstate/nn/_dyn_impl/_readout_test.py +54 -0
  75. brainstate/nn/_dynamics/__init__.py +37 -0
  76. brainstate/nn/_dynamics/_dynamics_base.py +631 -0
  77. brainstate/nn/_dynamics/_dynamics_base_test.py +79 -0
  78. brainstate/nn/_dynamics/_projection_base.py +346 -0
  79. brainstate/nn/_dynamics/_state_delay.py +453 -0
  80. brainstate/nn/_dynamics/_synouts.py +161 -0
  81. brainstate/nn/_dynamics/_synouts_test.py +58 -0
  82. brainstate/nn/_elementwise/__init__.py +22 -0
  83. brainstate/nn/_elementwise/_dropout.py +418 -0
  84. brainstate/nn/_elementwise/_dropout_test.py +100 -0
  85. brainstate/nn/_elementwise/_elementwise.py +1122 -0
  86. brainstate/nn/_elementwise/_elementwise_test.py +171 -0
  87. brainstate/nn/_exp_euler.py +97 -0
  88. brainstate/nn/_exp_euler_test.py +36 -0
  89. brainstate/nn/_interaction/__init__.py +32 -0
  90. brainstate/nn/_interaction/_connections.py +726 -0
  91. brainstate/nn/_interaction/_connections_test.py +254 -0
  92. brainstate/nn/_interaction/_embedding.py +59 -0
  93. brainstate/nn/_interaction/_normalizations.py +388 -0
  94. brainstate/nn/_interaction/_normalizations_test.py +75 -0
  95. brainstate/nn/_interaction/_poolings.py +1179 -0
  96. brainstate/nn/_interaction/_poolings_test.py +219 -0
  97. brainstate/nn/_module.py +328 -0
  98. brainstate/nn/_module_test.py +211 -0
  99. brainstate/nn/metrics.py +309 -309
  100. brainstate/optim/__init__.py +14 -2
  101. brainstate/optim/_base.py +66 -0
  102. brainstate/optim/_lr_scheduler.py +363 -400
  103. brainstate/optim/_lr_scheduler_test.py +25 -24
  104. brainstate/optim/_optax_optimizer.py +103 -176
  105. brainstate/optim/_optax_optimizer_test.py +41 -1
  106. brainstate/optim/_sgd_optimizer.py +950 -1025
  107. brainstate/random/_rand_funs.py +3269 -3268
  108. brainstate/random/_rand_funs_test.py +568 -0
  109. brainstate/random/_rand_seed.py +149 -117
  110. brainstate/random/_rand_seed_test.py +50 -0
  111. brainstate/random/_rand_state.py +1356 -1321
  112. brainstate/random/_random_for_unit.py +13 -13
  113. brainstate/surrogate.py +1262 -1243
  114. brainstate/{nn/_projection/_utils.py → transform.py} +1 -2
  115. brainstate/typing.py +157 -130
  116. brainstate/util/__init__.py +52 -0
  117. brainstate/util/_caller.py +100 -0
  118. brainstate/util/_dict.py +734 -0
  119. brainstate/util/_dict_test.py +160 -0
  120. brainstate/util/_error.py +28 -0
  121. brainstate/util/_filter.py +178 -0
  122. brainstate/util/_others.py +497 -0
  123. brainstate/util/_pretty_repr.py +208 -0
  124. brainstate/util/_scaling.py +260 -0
  125. brainstate/util/_struct.py +524 -0
  126. brainstate/util/_tracers.py +75 -0
  127. brainstate/{_visualization.py → util/_visualization.py} +16 -16
  128. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.dist-info}/METADATA +11 -11
  129. brainstate-0.1.0.dist-info/RECORD +135 -0
  130. brainstate/_module.py +0 -1637
  131. brainstate/_module_test.py +0 -207
  132. brainstate/nn/_base.py +0 -251
  133. brainstate/nn/_connections.py +0 -686
  134. brainstate/nn/_dynamics.py +0 -426
  135. brainstate/nn/_elementwise.py +0 -1438
  136. brainstate/nn/_embedding.py +0 -66
  137. brainstate/nn/_misc.py +0 -133
  138. brainstate/nn/_normalizations.py +0 -389
  139. brainstate/nn/_others.py +0 -101
  140. brainstate/nn/_poolings.py +0 -1229
  141. brainstate/nn/_poolings_test.py +0 -231
  142. brainstate/nn/_projection/_align_post.py +0 -546
  143. brainstate/nn/_projection/_align_pre.py +0 -599
  144. brainstate/nn/_projection/_delta.py +0 -241
  145. brainstate/nn/_projection/_vanilla.py +0 -101
  146. brainstate/nn/_rate_rnns.py +0 -410
  147. brainstate/nn/_readout.py +0 -136
  148. brainstate/nn/_synouts.py +0 -166
  149. brainstate/nn/event/csr.py +0 -312
  150. brainstate/nn/event/csr_test.py +0 -118
  151. brainstate/nn/event/fixed_probability.py +0 -276
  152. brainstate/nn/event/fixed_probability_test.py +0 -127
  153. brainstate/nn/event/linear.py +0 -220
  154. brainstate/nn/event/linear_test.py +0 -111
  155. brainstate/random/random_test.py +0 -593
  156. brainstate/transform/_autograd.py +0 -585
  157. brainstate/transform/_autograd_test.py +0 -1181
  158. brainstate/transform/_conditions.py +0 -334
  159. brainstate/transform/_conditions_test.py +0 -220
  160. brainstate/transform/_error_if.py +0 -94
  161. brainstate/transform/_error_if_test.py +0 -55
  162. brainstate/transform/_jit.py +0 -265
  163. brainstate/transform/_jit_test.py +0 -118
  164. brainstate/transform/_loop_collect_return.py +0 -502
  165. brainstate/transform/_loop_no_collection.py +0 -170
  166. brainstate/transform/_make_jaxpr.py +0 -739
  167. brainstate/transform/_make_jaxpr_test.py +0 -131
  168. brainstate/transform/_mapping.py +0 -109
  169. brainstate/transform/_progress_bar.py +0 -111
  170. brainstate/transform/_unvmap.py +0 -143
  171. brainstate/util.py +0 -746
  172. brainstate-0.0.2.post20241010.dist-info/RECORD +0 -87
  173. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.dist-info}/LICENSE +0 -0
  174. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.dist-info}/WHEEL +0 -0
  175. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.dist-info}/top_level.txt +0 -0
brainstate/init/_base.py CHANGED
@@ -13,24 +13,42 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
+ from __future__ import annotations
16
17
 
17
18
  from typing import Optional, Tuple
18
19
 
19
20
  import numpy as np
20
21
 
22
+ from brainstate.util import PrettyRepr, PrettyType, PrettyAttr
23
+
21
24
  __all__ = ['Initializer', 'to_size']
22
25
 
23
26
 
24
- class Initializer(object):
25
- def __call__(self, *args, **kwargs):
26
- raise NotImplementedError
27
+ class Initializer(PrettyRepr):
28
+ """
29
+ Base class for initializers.
30
+ """
31
+ __module__ = 'brainstate.init'
32
+
33
+ def __call__(self, *args, **kwargs):
34
+ raise NotImplementedError
35
+
36
+ def __pretty_repr__(self):
37
+ """
38
+ Pretty repr for the object.
39
+ """
40
+ yield PrettyType(type=type(self))
41
+ for name, value in vars(self).items():
42
+ if name.startswith('_'):
43
+ continue
44
+ yield PrettyAttr(name, repr(value))
27
45
 
28
46
 
29
47
  def to_size(x) -> Optional[Tuple[int]]:
30
- if isinstance(x, (tuple, list)):
31
- return tuple(x)
32
- if isinstance(x, (int, np.integer)):
33
- return (x,)
34
- if x is None:
35
- return x
36
- raise ValueError(f'Cannot make a size for {x}')
48
+ if isinstance(x, (tuple, list)):
49
+ return tuple(x)
50
+ if isinstance(x, (int, np.integer)):
51
+ return (x,)
52
+ if x is None:
53
+ return x
54
+ raise ValueError(f'Cannot make a size for {x}')
brainstate/init/_generic.py CHANGED
@@ -14,6 +14,7 @@
14
14
  # ==============================================================================
15
15
 
16
16
  # -*- coding: utf-8 -*-
17
+ from __future__ import annotations
17
18
 
18
19
  from typing import Union, Callable, Optional, Sequence
19
20
 
@@ -22,221 +23,223 @@ import jax
22
23
  import numpy as np
23
24
 
24
25
  from brainstate._state import State
26
+ from brainstate._utils import set_module_as
25
27
  from brainstate.typing import ArrayLike
26
28
  from ._base import to_size
27
- from brainstate.mixin import Mode
28
29
 
29
30
  __all__ = [
30
- 'param',
31
- 'state',
32
- 'noise',
31
+ 'param',
32
+ 'state',
33
+ 'noise',
33
34
  ]
34
35
 
35
36
 
36
37
  def _is_scalar(x):
37
- return bu.math.isscalar(x)
38
+ return bu.math.isscalar(x)
38
39
 
39
40
 
40
41
  def are_broadcastable_shapes(shape1, shape2):
41
- """
42
- Check if two shapes are broadcastable.
42
+ """
43
+ Check if two shapes are broadcastable.
43
44
 
44
- Parameters:
45
- - shape1: Tuple[int], the shape of the first array.
46
- - shape2: Tuple[int], the shape of the second array.
45
+ Parameters:
46
+ - shape1: Tuple[int], the shape of the first array.
47
+ - shape2: Tuple[int], the shape of the second array.
47
48
 
48
- Returns:
49
- - bool: True if shapes are broadcastable, False otherwise.
50
- """
51
- # Reverse the shapes to compare from the last dimension
52
- shape1_reversed = shape1[::-1]
53
- shape2_reversed = shape2[::-1]
49
+ Returns:
50
+ - bool: True if shapes are broadcastable, False otherwise.
51
+ """
52
+ # Reverse the shapes to compare from the last dimension
53
+ shape1_reversed = shape1[::-1]
54
+ shape2_reversed = shape2[::-1]
54
55
 
55
- # Iterate over the dimensions of the shorter shape
56
- for dim1, dim2 in zip(shape1_reversed, shape2_reversed):
57
- # Check if the dimensions are not equal and neither is 1
58
- if dim1 != dim2 and 1 not in (dim1, dim2):
59
- return False
56
+ # Iterate over the dimensions of the shorter shape
57
+ for dim1, dim2 in zip(shape1_reversed, shape2_reversed):
58
+ # Check if the dimensions are not equal and neither is 1
59
+ if dim1 != dim2 and 1 not in (dim1, dim2):
60
+ return False
60
61
 
61
- # If all dimensions are compatible, the shapes are broadcastable
62
- return True
62
+ # If all dimensions are compatible, the shapes are broadcastable
63
+ return True
63
64
 
64
65
 
65
66
  def _expand_params_to_match_sizes(params, sizes):
66
- """
67
- Expand the dimensions of params to match the dimensions of sizes.
67
+ """
68
+ Expand the dimensions of params to match the dimensions of sizes.
68
69
 
69
- Parameters:
70
- - params: jax.Array or np.ndarray, the parameter array to be expanded.
71
- - sizes: tuple[int] or list[int], the target shape dimensions.
70
+ Parameters:
71
+ - params: jax.Array or np.ndarray, the parameter array to be expanded.
72
+ - sizes: tuple[int] or list[int], the target shape dimensions.
72
73
 
73
- Returns:
74
- - Expanded params with dimensions matching sizes.
75
- """
76
- params_dim = params.ndim
77
- sizes_dim = len(sizes)
78
- dim_diff = sizes_dim - params_dim
74
+ Returns:
75
+ - Expanded params with dimensions matching sizes.
76
+ """
77
+ params_dim = params.ndim
78
+ sizes_dim = len(sizes)
79
+ dim_diff = sizes_dim - params_dim
79
80
 
80
- # Add new axes to params if it has fewer dimensions than sizes
81
- for _ in range(dim_diff):
82
- params = bu.math.expand_dims(params, axis=0) # Add new axis at the last dimension
83
- return params
81
+ # Add new axes to params if it has fewer dimensions than sizes
82
+ for _ in range(dim_diff):
83
+ params = bu.math.expand_dims(params, axis=0) # Add new axis at the last dimension
84
+ return params
84
85
 
85
86
 
87
+ @set_module_as('brainstate.init')
86
88
  def param(
87
89
  parameter: Union[Callable, ArrayLike, State],
88
90
  sizes: Union[int, Sequence[int]],
89
91
  batch_size: Optional[int] = None,
90
92
  allow_none: bool = True,
91
93
  allow_scalar: bool = True,
92
- mode: Mode = None,
93
94
  ):
94
- """Initialize parameters.
95
-
96
- Parameters
97
- ----------
98
- parameter: callable, ArrayLike, State
99
- The initialization of the parameter.
100
- - If it is None, the created parameter will be None.
101
- - If it is a callable function :math:`f`, the ``f(size)`` will be returned.
102
- - If it is an instance of :py:class:`init.Initializer``, the ``f(size)`` will be returned.
103
- - If it is a tensor, then this function check whether ``tensor.shape`` is equal to the given ``size``.
104
- sizes: int, sequence of int
105
- The shape of the parameter.
106
- batch_size: int
107
- The batch size.
108
- allow_none: bool
109
- Whether allow the parameter is None.
110
- allow_scalar: bool
111
- Whether allow the parameter is a scalar value.
112
-
113
- Returns
114
- -------
115
- param: ArrayType, float, int, bool, None
116
- The initialized parameter.
117
-
118
- See Also
119
- --------
120
- noise, state
121
- """
122
- # Check if the parameter is None
123
- if parameter is None:
124
- if allow_none:
125
- return None
95
+ """Initialize parameters.
96
+
97
+ Parameters
98
+ ----------
99
+ parameter: callable, ArrayLike, State
100
+ The initialization of the parameter.
101
+ - If it is None, the created parameter will be None.
102
+ - If it is a callable function :math:`f`, the ``f(size)`` will be returned.
103
+ - If it is an instance of :py:class:`init.Initializer``, the ``f(size)`` will be returned.
104
+ - If it is a tensor, then this function check whether ``tensor.shape`` is equal to the given ``size``.
105
+ sizes: int, sequence of int
106
+ The shape of the parameter.
107
+ batch_size: int
108
+ The batch size.
109
+ allow_none: bool
110
+ Whether allow the parameter is None.
111
+ allow_scalar: bool
112
+ Whether allow the parameter is a scalar value.
113
+
114
+ Returns
115
+ -------
116
+ param: ArrayType, float, int, bool, None
117
+ The initialized parameter.
118
+
119
+ See Also
120
+ --------
121
+ noise, state
122
+ """
123
+ # Check if the parameter is None
124
+ if parameter is None:
125
+ if allow_none:
126
+ return None
127
+ else:
128
+ raise ValueError(f'Expect a parameter with type of float, ArrayType, Initializer, or '
129
+ f'Callable function, but we got None. ')
130
+
131
+ # Check if the parameter is a scalar value
132
+ if allow_scalar and _is_scalar(parameter):
133
+ return parameter
134
+
135
+ # Convert sizes to a tuple
136
+ sizes = tuple(to_size(sizes))
137
+
138
+ # Check if the parameter is a callable function
139
+ if callable(parameter):
140
+ if batch_size is not None:
141
+ sizes = (batch_size,) + sizes
142
+ return parameter(sizes)
143
+ elif isinstance(parameter, (np.ndarray, jax.Array, bu.Quantity, State)):
144
+ parameter = parameter
126
145
  else:
127
- raise ValueError(f'Expect a parameter with type of float, ArrayType, Initializer, or '
128
- f'Callable function, but we got None. ')
146
+ raise ValueError(f'Unknown parameter type: {type(parameter)}')
129
147
 
130
- # Check if the parameter is a scalar value
131
- if allow_scalar and _is_scalar(parameter):
132
- return parameter
148
+ # Check if the shape of the parameter matches the given size
149
+ if not are_broadcastable_shapes(parameter.shape, sizes):
150
+ raise ValueError(f'The shape of the parameter {parameter.shape} does not match with the given size {sizes}')
133
151
 
134
- # Convert sizes to a tuple
135
- sizes = tuple(to_size(sizes))
136
-
137
- # Check if the parameter is a callable function
138
- if callable(parameter):
152
+ # Expand the parameter to match the given batch size
153
+ param_value = parameter.value if isinstance(parameter, State) else parameter
139
154
  if batch_size is not None:
140
- sizes = (batch_size,) + sizes
141
- return parameter(sizes)
142
- elif isinstance(parameter, (np.ndarray, jax.Array, bu.Quantity, State)):
143
- parameter = parameter
144
- else:
145
- raise ValueError(f'Unknown parameter type: {type(parameter)}')
146
-
147
- # Check if the shape of the parameter matches the given size
148
- if not are_broadcastable_shapes(parameter.shape, sizes):
149
- raise ValueError(f'The shape of the parameter {parameter.shape} does not match with the given size {sizes}')
150
-
151
- # Expand the parameter to match the given batch size
152
- param_value = parameter.value if isinstance(parameter, State) else parameter
153
- if batch_size is not None:
154
- if param_value.ndim <= len(sizes):
155
- # add a new axis to the params so that it matches the dimensionality of the given shape ``sizes``
156
- param_value = _expand_params_to_match_sizes(param_value, sizes)
157
- param_value = bu.math.repeat(
158
- bu.math.expand_dims(param_value, axis=0),
159
- batch_size,
160
- axis=0
161
- )
162
- else:
163
- if param_value.shape[0] != batch_size:
164
- raise ValueError(f'The batch size of the parameter {param_value.shape[0]} '
165
- f'does not match with the given batch size {batch_size}')
166
- return type(parameter)(param_value) if isinstance(parameter, State) else param_value
167
-
168
-
155
+ if param_value.ndim <= len(sizes):
156
+ # add a new axis to the params so that it matches the dimensionality of the given shape ``sizes``
157
+ param_value = _expand_params_to_match_sizes(param_value, sizes)
158
+ param_value = bu.math.repeat(
159
+ bu.math.expand_dims(param_value, axis=0),
160
+ batch_size,
161
+ axis=0
162
+ )
163
+ else:
164
+ if param_value.shape[0] != batch_size:
165
+ raise ValueError(f'The batch size of the parameter {param_value.shape[0]} '
166
+ f'does not match with the given batch size {batch_size}')
167
+ return type(parameter)(param_value) if isinstance(parameter, State) else param_value
168
+
169
+
170
+ @set_module_as('brainstate.init')
169
171
  def state(
170
172
  init: Union[Callable, jax.typing.ArrayLike],
171
173
  sizes: Union[int, Sequence[int]] = None,
172
174
  batch_size: Optional[int] = None,
173
175
  ):
174
- """
175
- Initialize a :math:`~.State` from a callable function or a data.
176
- """
177
- sizes = to_size(sizes)
178
- if callable(init):
179
- if sizes is None:
180
- raise ValueError('"varshape" cannot be None when data is a callable function.')
181
- sizes = list(sizes)
182
- if isinstance(batch_size, int):
183
- sizes.insert(0, batch_size)
184
- return State(init(sizes))
185
-
186
- else:
187
- if sizes is not None:
188
- if bu.math.shape(init) != sizes:
189
- raise ValueError(f'The shape of "data" {bu.math.shape(init)} does not match with "var_shape" {sizes}')
190
- if isinstance(batch_size, int):
191
- batch_size = batch_size
192
- data = State(
193
- bu.math.repeat(
194
- bu.math.expand_dims(init, axis=0),
195
- batch_size,
196
- axis=0
197
- )
198
- )
199
- else:
200
- data = State(init)
201
- return data
202
-
176
+ """
177
+ Initialize a :math:`~.State` from a callable function or a data.
178
+ """
179
+ sizes = to_size(sizes)
180
+ if callable(init):
181
+ if sizes is None:
182
+ raise ValueError('"varshape" cannot be None when data is a callable function.')
183
+ sizes = list(sizes)
184
+ if isinstance(batch_size, int):
185
+ sizes.insert(0, batch_size)
186
+ return State(init(sizes))
203
187
 
188
+ else:
189
+ if sizes is not None:
190
+ if bu.math.shape(init) != sizes:
191
+ raise ValueError(f'The shape of "data" {bu.math.shape(init)} does not match with "var_shape" {sizes}')
192
+ if isinstance(batch_size, int):
193
+ batch_size = batch_size
194
+ data = State(
195
+ bu.math.repeat(
196
+ bu.math.expand_dims(init, axis=0),
197
+ batch_size,
198
+ axis=0
199
+ )
200
+ )
201
+ else:
202
+ data = State(init)
203
+ return data
204
+
205
+
206
+ @set_module_as('brainstate.init')
204
207
  def noise(
205
208
  noises: Optional[Union[ArrayLike, Callable]],
206
209
  size: Union[int, Sequence[int]],
207
210
  num_vars: int = 1,
208
211
  noise_idx: int = 0,
209
212
  ) -> Optional[Callable]:
210
- """Initialize a noise function.
211
-
212
- Parameters
213
- ----------
214
- noises: Any
215
- size: Shape
216
- The size of the noise.
217
- num_vars: int
218
- The number of variables.
219
- noise_idx: int
220
- The index of the current noise among all noise variables.
221
-
222
- Returns
223
- -------
224
- noise_func: function, None
225
- The noise function.
226
-
227
- See Also
228
- --------
229
- variable_, parameter, delay
230
-
231
- """
232
- if callable(noises):
233
- return noises
234
- elif noises is None:
235
- return None
236
- else:
237
- noises = param(noises, size, allow_none=False)
238
- if num_vars > 1:
239
- noises_ = [None] * num_vars
240
- noises_[noise_idx] = noises
241
- noises = tuple(noises_)
242
- return lambda *args, **kwargs: noises
213
+ """Initialize a noise function.
214
+
215
+ Parameters
216
+ ----------
217
+ noises: Any
218
+ size: Shape
219
+ The size of the noise.
220
+ num_vars: int
221
+ The number of variables.
222
+ noise_idx: int
223
+ The index of the current noise among all noise variables.
224
+
225
+ Returns
226
+ -------
227
+ noise_func: function, None
228
+ The noise function.
229
+
230
+ See Also
231
+ --------
232
+ variable_, parameter, delay
233
+
234
+ """
235
+ if callable(noises):
236
+ return noises
237
+ elif noises is None:
238
+ return None
239
+ else:
240
+ noises = param(noises, size, allow_none=False)
241
+ if num_vars > 1:
242
+ noises_ = [None] * num_vars
243
+ noises_[noise_idx] = noises
244
+ noises = tuple(noises_)
245
+ return lambda *args, **kwargs: noises