brainstate 0.0.2.post20241010__py2.py3-none-any.whl → 0.1.0.post20241122__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (184) hide show
  1. benchmark/COBA_2005.py +125 -0
  2. benchmark/CUBA_2005.py +149 -0
  3. brainstate/__init__.py +31 -11
  4. brainstate/_state.py +760 -316
  5. brainstate/_state_test.py +41 -12
  6. brainstate/_utils.py +31 -4
  7. brainstate/augment/__init__.py +40 -0
  8. brainstate/augment/_autograd.py +611 -0
  9. brainstate/augment/_autograd_test.py +1193 -0
  10. brainstate/augment/_eval_shape.py +102 -0
  11. brainstate/augment/_eval_shape_test.py +40 -0
  12. brainstate/augment/_mapping.py +525 -0
  13. brainstate/augment/_mapping_test.py +210 -0
  14. brainstate/augment/_random.py +99 -0
  15. brainstate/{transform → compile}/__init__.py +25 -13
  16. brainstate/compile/_ad_checkpoint.py +204 -0
  17. brainstate/compile/_ad_checkpoint_test.py +51 -0
  18. brainstate/compile/_conditions.py +259 -0
  19. brainstate/compile/_conditions_test.py +221 -0
  20. brainstate/compile/_error_if.py +94 -0
  21. brainstate/compile/_error_if_test.py +54 -0
  22. brainstate/compile/_jit.py +314 -0
  23. brainstate/compile/_jit_test.py +143 -0
  24. brainstate/compile/_loop_collect_return.py +516 -0
  25. brainstate/compile/_loop_collect_return_test.py +59 -0
  26. brainstate/compile/_loop_no_collection.py +185 -0
  27. brainstate/compile/_loop_no_collection_test.py +51 -0
  28. brainstate/compile/_make_jaxpr.py +756 -0
  29. brainstate/compile/_make_jaxpr_test.py +134 -0
  30. brainstate/compile/_progress_bar.py +111 -0
  31. brainstate/compile/_unvmap.py +159 -0
  32. brainstate/compile/_util.py +147 -0
  33. brainstate/environ.py +408 -381
  34. brainstate/environ_test.py +34 -32
  35. brainstate/event/__init__.py +27 -0
  36. brainstate/event/_csr.py +316 -0
  37. brainstate/event/_csr_benchmark.py +14 -0
  38. brainstate/event/_csr_test.py +118 -0
  39. brainstate/event/_fixed_probability.py +708 -0
  40. brainstate/event/_fixed_probability_benchmark.py +128 -0
  41. brainstate/event/_fixed_probability_test.py +131 -0
  42. brainstate/event/_linear.py +359 -0
  43. brainstate/event/_linear_benckmark.py +82 -0
  44. brainstate/event/_linear_test.py +117 -0
  45. brainstate/{nn/event → event}/_misc.py +7 -7
  46. brainstate/event/_xla_custom_op.py +312 -0
  47. brainstate/event/_xla_custom_op_test.py +55 -0
  48. brainstate/functional/_activations.py +521 -511
  49. brainstate/functional/_activations_test.py +300 -300
  50. brainstate/functional/_normalization.py +43 -43
  51. brainstate/functional/_others.py +15 -15
  52. brainstate/functional/_spikes.py +49 -49
  53. brainstate/graph/__init__.py +33 -0
  54. brainstate/graph/_graph_context.py +443 -0
  55. brainstate/graph/_graph_context_test.py +65 -0
  56. brainstate/graph/_graph_convert.py +246 -0
  57. brainstate/graph/_graph_node.py +300 -0
  58. brainstate/graph/_graph_node_test.py +75 -0
  59. brainstate/graph/_graph_operation.py +1746 -0
  60. brainstate/graph/_graph_operation_test.py +724 -0
  61. brainstate/init/_base.py +28 -10
  62. brainstate/init/_generic.py +175 -172
  63. brainstate/init/_random_inits.py +470 -415
  64. brainstate/init/_random_inits_test.py +150 -0
  65. brainstate/init/_regular_inits.py +66 -69
  66. brainstate/init/_regular_inits_test.py +51 -0
  67. brainstate/mixin.py +236 -244
  68. brainstate/mixin_test.py +44 -46
  69. brainstate/nn/__init__.py +26 -51
  70. brainstate/nn/_collective_ops.py +199 -0
  71. brainstate/nn/_dyn_impl/__init__.py +46 -0
  72. brainstate/nn/_dyn_impl/_dynamics_neuron.py +290 -0
  73. brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +162 -0
  74. brainstate/nn/_dyn_impl/_dynamics_synapse.py +315 -0
  75. brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +132 -0
  76. brainstate/nn/_dyn_impl/_inputs.py +154 -0
  77. brainstate/nn/{event/__init__.py → _dyn_impl/_projection_alignpost.py} +8 -8
  78. brainstate/nn/_dyn_impl/_rate_rnns.py +400 -0
  79. brainstate/nn/_dyn_impl/_rate_rnns_test.py +64 -0
  80. brainstate/nn/_dyn_impl/_readout.py +128 -0
  81. brainstate/nn/_dyn_impl/_readout_test.py +54 -0
  82. brainstate/nn/_dynamics/__init__.py +37 -0
  83. brainstate/nn/_dynamics/_dynamics_base.py +631 -0
  84. brainstate/nn/_dynamics/_dynamics_base_test.py +79 -0
  85. brainstate/nn/_dynamics/_projection_base.py +346 -0
  86. brainstate/nn/_dynamics/_state_delay.py +453 -0
  87. brainstate/nn/_dynamics/_synouts.py +161 -0
  88. brainstate/nn/_dynamics/_synouts_test.py +58 -0
  89. brainstate/nn/_elementwise/__init__.py +22 -0
  90. brainstate/nn/_elementwise/_dropout.py +418 -0
  91. brainstate/nn/_elementwise/_dropout_test.py +100 -0
  92. brainstate/nn/_elementwise/_elementwise.py +1122 -0
  93. brainstate/nn/_elementwise/_elementwise_test.py +171 -0
  94. brainstate/nn/_exp_euler.py +97 -0
  95. brainstate/nn/_exp_euler_test.py +36 -0
  96. brainstate/nn/_interaction/__init__.py +41 -0
  97. brainstate/nn/_interaction/_conv.py +499 -0
  98. brainstate/nn/_interaction/_conv_test.py +239 -0
  99. brainstate/nn/_interaction/_embedding.py +59 -0
  100. brainstate/nn/_interaction/_linear.py +582 -0
  101. brainstate/nn/_interaction/_linear_test.py +42 -0
  102. brainstate/nn/_interaction/_normalizations.py +388 -0
  103. brainstate/nn/_interaction/_normalizations_test.py +75 -0
  104. brainstate/nn/_interaction/_poolings.py +1179 -0
  105. brainstate/nn/_interaction/_poolings_test.py +219 -0
  106. brainstate/nn/_module.py +328 -0
  107. brainstate/nn/_module_test.py +211 -0
  108. brainstate/nn/metrics.py +309 -309
  109. brainstate/optim/__init__.py +14 -2
  110. brainstate/optim/_base.py +66 -0
  111. brainstate/optim/_lr_scheduler.py +363 -400
  112. brainstate/optim/_lr_scheduler_test.py +25 -24
  113. brainstate/optim/_optax_optimizer.py +121 -176
  114. brainstate/optim/_optax_optimizer_test.py +41 -1
  115. brainstate/optim/_sgd_optimizer.py +950 -1025
  116. brainstate/random/_rand_funs.py +3269 -3268
  117. brainstate/random/_rand_funs_test.py +568 -0
  118. brainstate/random/_rand_seed.py +149 -117
  119. brainstate/random/_rand_seed_test.py +50 -0
  120. brainstate/random/_rand_state.py +1356 -1321
  121. brainstate/random/_random_for_unit.py +13 -13
  122. brainstate/surrogate.py +1262 -1243
  123. brainstate/{nn/_projection/_utils.py → transform.py} +1 -2
  124. brainstate/typing.py +157 -130
  125. brainstate/util/__init__.py +52 -0
  126. brainstate/util/_caller.py +100 -0
  127. brainstate/util/_dict.py +734 -0
  128. brainstate/util/_dict_test.py +160 -0
  129. brainstate/{nn/_projection/__init__.py → util/_error.py} +9 -13
  130. brainstate/util/_filter.py +178 -0
  131. brainstate/util/_others.py +497 -0
  132. brainstate/util/_pretty_repr.py +208 -0
  133. brainstate/util/_scaling.py +260 -0
  134. brainstate/util/_struct.py +524 -0
  135. brainstate/util/_tracers.py +75 -0
  136. brainstate/{_visualization.py → util/_visualization.py} +16 -16
  137. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/METADATA +11 -11
  138. brainstate-0.1.0.post20241122.dist-info/RECORD +144 -0
  139. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/top_level.txt +1 -0
  140. brainstate/_module.py +0 -1637
  141. brainstate/_module_test.py +0 -207
  142. brainstate/nn/_base.py +0 -251
  143. brainstate/nn/_connections.py +0 -686
  144. brainstate/nn/_dynamics.py +0 -426
  145. brainstate/nn/_elementwise.py +0 -1438
  146. brainstate/nn/_embedding.py +0 -66
  147. brainstate/nn/_misc.py +0 -133
  148. brainstate/nn/_normalizations.py +0 -389
  149. brainstate/nn/_others.py +0 -101
  150. brainstate/nn/_poolings.py +0 -1229
  151. brainstate/nn/_poolings_test.py +0 -231
  152. brainstate/nn/_projection/_align_post.py +0 -546
  153. brainstate/nn/_projection/_align_pre.py +0 -599
  154. brainstate/nn/_projection/_delta.py +0 -241
  155. brainstate/nn/_projection/_vanilla.py +0 -101
  156. brainstate/nn/_rate_rnns.py +0 -410
  157. brainstate/nn/_readout.py +0 -136
  158. brainstate/nn/_synouts.py +0 -166
  159. brainstate/nn/event/csr.py +0 -312
  160. brainstate/nn/event/csr_test.py +0 -118
  161. brainstate/nn/event/fixed_probability.py +0 -276
  162. brainstate/nn/event/fixed_probability_test.py +0 -127
  163. brainstate/nn/event/linear.py +0 -220
  164. brainstate/nn/event/linear_test.py +0 -111
  165. brainstate/random/random_test.py +0 -593
  166. brainstate/transform/_autograd.py +0 -585
  167. brainstate/transform/_autograd_test.py +0 -1181
  168. brainstate/transform/_conditions.py +0 -334
  169. brainstate/transform/_conditions_test.py +0 -220
  170. brainstate/transform/_error_if.py +0 -94
  171. brainstate/transform/_error_if_test.py +0 -55
  172. brainstate/transform/_jit.py +0 -265
  173. brainstate/transform/_jit_test.py +0 -118
  174. brainstate/transform/_loop_collect_return.py +0 -502
  175. brainstate/transform/_loop_no_collection.py +0 -170
  176. brainstate/transform/_make_jaxpr.py +0 -739
  177. brainstate/transform/_make_jaxpr_test.py +0 -131
  178. brainstate/transform/_mapping.py +0 -109
  179. brainstate/transform/_progress_bar.py +0 -111
  180. brainstate/transform/_unvmap.py +0 -143
  181. brainstate/util.py +0 -746
  182. brainstate-0.0.2.post20241010.dist-info/RECORD +0 -87
  183. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/LICENSE +0 -0
  184. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/WHEEL +0 -0
brainstate/init/_base.py CHANGED
@@ -13,24 +13,42 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
+ from __future__ import annotations
16
17
 
17
18
  from typing import Optional, Tuple
18
19
 
19
20
  import numpy as np
20
21
 
22
+ from brainstate.util import PrettyRepr, PrettyType, PrettyAttr
23
+
21
24
  __all__ = ['Initializer', 'to_size']
22
25
 
23
26
 
24
- class Initializer(object):
25
- def __call__(self, *args, **kwargs):
26
- raise NotImplementedError
27
+ class Initializer(PrettyRepr):
28
+ """
29
+ Base class for initializers.
30
+ """
31
+ __module__ = 'brainstate.init'
32
+
33
+ def __call__(self, *args, **kwargs):
34
+ raise NotImplementedError
35
+
36
+ def __pretty_repr__(self):
37
+ """
38
+ Pretty repr for the object.
39
+ """
40
+ yield PrettyType(type=type(self))
41
+ for name, value in vars(self).items():
42
+ if name.startswith('_'):
43
+ continue
44
+ yield PrettyAttr(name, repr(value))
27
45
 
28
46
 
29
47
  def to_size(x) -> Optional[Tuple[int]]:
30
- if isinstance(x, (tuple, list)):
31
- return tuple(x)
32
- if isinstance(x, (int, np.integer)):
33
- return (x,)
34
- if x is None:
35
- return x
36
- raise ValueError(f'Cannot make a size for {x}')
48
+ if isinstance(x, (tuple, list)):
49
+ return tuple(x)
50
+ if isinstance(x, (int, np.integer)):
51
+ return (x,)
52
+ if x is None:
53
+ return x
54
+ raise ValueError(f'Cannot make a size for {x}')
@@ -14,6 +14,7 @@
14
14
  # ==============================================================================
15
15
 
16
16
  # -*- coding: utf-8 -*-
17
+ from __future__ import annotations
17
18
 
18
19
  from typing import Union, Callable, Optional, Sequence
19
20
 
@@ -22,221 +23,223 @@ import jax
22
23
  import numpy as np
23
24
 
24
25
  from brainstate._state import State
26
+ from brainstate._utils import set_module_as
25
27
  from brainstate.typing import ArrayLike
26
28
  from ._base import to_size
27
- from brainstate.mixin import Mode
28
29
 
29
30
  __all__ = [
30
- 'param',
31
- 'state',
32
- 'noise',
31
+ 'param',
32
+ 'state',
33
+ 'noise',
33
34
  ]
34
35
 
35
36
 
36
37
  def _is_scalar(x):
37
- return bu.math.isscalar(x)
38
+ return bu.math.isscalar(x)
38
39
 
39
40
 
40
41
  def are_broadcastable_shapes(shape1, shape2):
41
- """
42
- Check if two shapes are broadcastable.
42
+ """
43
+ Check if two shapes are broadcastable.
43
44
 
44
- Parameters:
45
- - shape1: Tuple[int], the shape of the first array.
46
- - shape2: Tuple[int], the shape of the second array.
45
+ Parameters:
46
+ - shape1: Tuple[int], the shape of the first array.
47
+ - shape2: Tuple[int], the shape of the second array.
47
48
 
48
- Returns:
49
- - bool: True if shapes are broadcastable, False otherwise.
50
- """
51
- # Reverse the shapes to compare from the last dimension
52
- shape1_reversed = shape1[::-1]
53
- shape2_reversed = shape2[::-1]
49
+ Returns:
50
+ - bool: True if shapes are broadcastable, False otherwise.
51
+ """
52
+ # Reverse the shapes to compare from the last dimension
53
+ shape1_reversed = shape1[::-1]
54
+ shape2_reversed = shape2[::-1]
54
55
 
55
- # Iterate over the dimensions of the shorter shape
56
- for dim1, dim2 in zip(shape1_reversed, shape2_reversed):
57
- # Check if the dimensions are not equal and neither is 1
58
- if dim1 != dim2 and 1 not in (dim1, dim2):
59
- return False
56
+ # Iterate over the dimensions of the shorter shape
57
+ for dim1, dim2 in zip(shape1_reversed, shape2_reversed):
58
+ # Check if the dimensions are not equal and neither is 1
59
+ if dim1 != dim2 and 1 not in (dim1, dim2):
60
+ return False
60
61
 
61
- # If all dimensions are compatible, the shapes are broadcastable
62
- return True
62
+ # If all dimensions are compatible, the shapes are broadcastable
63
+ return True
63
64
 
64
65
 
65
66
  def _expand_params_to_match_sizes(params, sizes):
66
- """
67
- Expand the dimensions of params to match the dimensions of sizes.
67
+ """
68
+ Expand the dimensions of params to match the dimensions of sizes.
68
69
 
69
- Parameters:
70
- - params: jax.Array or np.ndarray, the parameter array to be expanded.
71
- - sizes: tuple[int] or list[int], the target shape dimensions.
70
+ Parameters:
71
+ - params: jax.Array or np.ndarray, the parameter array to be expanded.
72
+ - sizes: tuple[int] or list[int], the target shape dimensions.
72
73
 
73
- Returns:
74
- - Expanded params with dimensions matching sizes.
75
- """
76
- params_dim = params.ndim
77
- sizes_dim = len(sizes)
78
- dim_diff = sizes_dim - params_dim
74
+ Returns:
75
+ - Expanded params with dimensions matching sizes.
76
+ """
77
+ params_dim = params.ndim
78
+ sizes_dim = len(sizes)
79
+ dim_diff = sizes_dim - params_dim
79
80
 
80
- # Add new axes to params if it has fewer dimensions than sizes
81
- for _ in range(dim_diff):
82
- params = bu.math.expand_dims(params, axis=0) # Add new axis at the last dimension
83
- return params
81
+ # Add new axes to params if it has fewer dimensions than sizes
82
+ for _ in range(dim_diff):
83
+ params = bu.math.expand_dims(params, axis=0) # Add new axis at the last dimension
84
+ return params
84
85
 
85
86
 
87
+ @set_module_as('brainstate.init')
86
88
  def param(
87
89
  parameter: Union[Callable, ArrayLike, State],
88
90
  sizes: Union[int, Sequence[int]],
89
91
  batch_size: Optional[int] = None,
90
92
  allow_none: bool = True,
91
93
  allow_scalar: bool = True,
92
- mode: Mode = None,
93
94
  ):
94
- """Initialize parameters.
95
-
96
- Parameters
97
- ----------
98
- parameter: callable, ArrayLike, State
99
- The initialization of the parameter.
100
- - If it is None, the created parameter will be None.
101
- - If it is a callable function :math:`f`, the ``f(size)`` will be returned.
102
- - If it is an instance of :py:class:`init.Initializer``, the ``f(size)`` will be returned.
103
- - If it is a tensor, then this function check whether ``tensor.shape`` is equal to the given ``size``.
104
- sizes: int, sequence of int
105
- The shape of the parameter.
106
- batch_size: int
107
- The batch size.
108
- allow_none: bool
109
- Whether allow the parameter is None.
110
- allow_scalar: bool
111
- Whether allow the parameter is a scalar value.
112
-
113
- Returns
114
- -------
115
- param: ArrayType, float, int, bool, None
116
- The initialized parameter.
117
-
118
- See Also
119
- --------
120
- noise, state
121
- """
122
- # Check if the parameter is None
123
- if parameter is None:
124
- if allow_none:
125
- return None
95
+ """Initialize parameters.
96
+
97
+ Parameters
98
+ ----------
99
+ parameter: callable, ArrayLike, State
100
+ The initialization of the parameter.
101
+ - If it is None, the created parameter will be None.
102
+ - If it is a callable function :math:`f`, the ``f(size)`` will be returned.
103
+ - If it is an instance of :py:class:`init.Initializer``, the ``f(size)`` will be returned.
104
+ - If it is a tensor, then this function check whether ``tensor.shape`` is equal to the given ``size``.
105
+ sizes: int, sequence of int
106
+ The shape of the parameter.
107
+ batch_size: int
108
+ The batch size.
109
+ allow_none: bool
110
+ Whether allow the parameter is None.
111
+ allow_scalar: bool
112
+ Whether allow the parameter is a scalar value.
113
+
114
+ Returns
115
+ -------
116
+ param: ArrayType, float, int, bool, None
117
+ The initialized parameter.
118
+
119
+ See Also
120
+ --------
121
+ noise, state
122
+ """
123
+ # Check if the parameter is None
124
+ if parameter is None:
125
+ if allow_none:
126
+ return None
127
+ else:
128
+ raise ValueError(f'Expect a parameter with type of float, ArrayType, Initializer, or '
129
+ f'Callable function, but we got None. ')
130
+
131
+ # Check if the parameter is a scalar value
132
+ if allow_scalar and _is_scalar(parameter):
133
+ return parameter
134
+
135
+ # Convert sizes to a tuple
136
+ sizes = tuple(to_size(sizes))
137
+
138
+ # Check if the parameter is a callable function
139
+ if callable(parameter):
140
+ if batch_size is not None:
141
+ sizes = (batch_size,) + sizes
142
+ return parameter(sizes)
143
+ elif isinstance(parameter, (np.ndarray, jax.Array, bu.Quantity, State)):
144
+ parameter = parameter
126
145
  else:
127
- raise ValueError(f'Expect a parameter with type of float, ArrayType, Initializer, or '
128
- f'Callable function, but we got None. ')
146
+ raise ValueError(f'Unknown parameter type: {type(parameter)}')
129
147
 
130
- # Check if the parameter is a scalar value
131
- if allow_scalar and _is_scalar(parameter):
132
- return parameter
148
+ # Check if the shape of the parameter matches the given size
149
+ if not are_broadcastable_shapes(parameter.shape, sizes):
150
+ raise ValueError(f'The shape of the parameter {parameter.shape} does not match with the given size {sizes}')
133
151
 
134
- # Convert sizes to a tuple
135
- sizes = tuple(to_size(sizes))
136
-
137
- # Check if the parameter is a callable function
138
- if callable(parameter):
152
+ # Expand the parameter to match the given batch size
153
+ param_value = parameter.value if isinstance(parameter, State) else parameter
139
154
  if batch_size is not None:
140
- sizes = (batch_size,) + sizes
141
- return parameter(sizes)
142
- elif isinstance(parameter, (np.ndarray, jax.Array, bu.Quantity, State)):
143
- parameter = parameter
144
- else:
145
- raise ValueError(f'Unknown parameter type: {type(parameter)}')
146
-
147
- # Check if the shape of the parameter matches the given size
148
- if not are_broadcastable_shapes(parameter.shape, sizes):
149
- raise ValueError(f'The shape of the parameter {parameter.shape} does not match with the given size {sizes}')
150
-
151
- # Expand the parameter to match the given batch size
152
- param_value = parameter.value if isinstance(parameter, State) else parameter
153
- if batch_size is not None:
154
- if param_value.ndim <= len(sizes):
155
- # add a new axis to the params so that it matches the dimensionality of the given shape ``sizes``
156
- param_value = _expand_params_to_match_sizes(param_value, sizes)
157
- param_value = bu.math.repeat(
158
- bu.math.expand_dims(param_value, axis=0),
159
- batch_size,
160
- axis=0
161
- )
162
- else:
163
- if param_value.shape[0] != batch_size:
164
- raise ValueError(f'The batch size of the parameter {param_value.shape[0]} '
165
- f'does not match with the given batch size {batch_size}')
166
- return type(parameter)(param_value) if isinstance(parameter, State) else param_value
167
-
168
-
155
+ if param_value.ndim <= len(sizes):
156
+ # add a new axis to the params so that it matches the dimensionality of the given shape ``sizes``
157
+ param_value = _expand_params_to_match_sizes(param_value, sizes)
158
+ param_value = bu.math.repeat(
159
+ bu.math.expand_dims(param_value, axis=0),
160
+ batch_size,
161
+ axis=0
162
+ )
163
+ else:
164
+ if param_value.shape[0] != batch_size:
165
+ raise ValueError(f'The batch size of the parameter {param_value.shape[0]} '
166
+ f'does not match with the given batch size {batch_size}')
167
+ return type(parameter)(param_value) if isinstance(parameter, State) else param_value
168
+
169
+
170
+ @set_module_as('brainstate.init')
169
171
  def state(
170
172
  init: Union[Callable, jax.typing.ArrayLike],
171
173
  sizes: Union[int, Sequence[int]] = None,
172
174
  batch_size: Optional[int] = None,
173
175
  ):
174
- """
175
- Initialize a :math:`~.State` from a callable function or a data.
176
- """
177
- sizes = to_size(sizes)
178
- if callable(init):
179
- if sizes is None:
180
- raise ValueError('"varshape" cannot be None when data is a callable function.')
181
- sizes = list(sizes)
182
- if isinstance(batch_size, int):
183
- sizes.insert(0, batch_size)
184
- return State(init(sizes))
185
-
186
- else:
187
- if sizes is not None:
188
- if bu.math.shape(init) != sizes:
189
- raise ValueError(f'The shape of "data" {bu.math.shape(init)} does not match with "var_shape" {sizes}')
190
- if isinstance(batch_size, int):
191
- batch_size = batch_size
192
- data = State(
193
- bu.math.repeat(
194
- bu.math.expand_dims(init, axis=0),
195
- batch_size,
196
- axis=0
197
- )
198
- )
199
- else:
200
- data = State(init)
201
- return data
202
-
176
+ """
177
+ Initialize a :math:`~.State` from a callable function or a data.
178
+ """
179
+ sizes = to_size(sizes)
180
+ if callable(init):
181
+ if sizes is None:
182
+ raise ValueError('"varshape" cannot be None when data is a callable function.')
183
+ sizes = list(sizes)
184
+ if isinstance(batch_size, int):
185
+ sizes.insert(0, batch_size)
186
+ return State(init(sizes))
203
187
 
188
+ else:
189
+ if sizes is not None:
190
+ if bu.math.shape(init) != sizes:
191
+ raise ValueError(f'The shape of "data" {bu.math.shape(init)} does not match with "var_shape" {sizes}')
192
+ if isinstance(batch_size, int):
193
+ batch_size = batch_size
194
+ data = State(
195
+ bu.math.repeat(
196
+ bu.math.expand_dims(init, axis=0),
197
+ batch_size,
198
+ axis=0
199
+ )
200
+ )
201
+ else:
202
+ data = State(init)
203
+ return data
204
+
205
+
206
+ @set_module_as('brainstate.init')
204
207
  def noise(
205
208
  noises: Optional[Union[ArrayLike, Callable]],
206
209
  size: Union[int, Sequence[int]],
207
210
  num_vars: int = 1,
208
211
  noise_idx: int = 0,
209
212
  ) -> Optional[Callable]:
210
- """Initialize a noise function.
211
-
212
- Parameters
213
- ----------
214
- noises: Any
215
- size: Shape
216
- The size of the noise.
217
- num_vars: int
218
- The number of variables.
219
- noise_idx: int
220
- The index of the current noise among all noise variables.
221
-
222
- Returns
223
- -------
224
- noise_func: function, None
225
- The noise function.
226
-
227
- See Also
228
- --------
229
- variable_, parameter, delay
230
-
231
- """
232
- if callable(noises):
233
- return noises
234
- elif noises is None:
235
- return None
236
- else:
237
- noises = param(noises, size, allow_none=False)
238
- if num_vars > 1:
239
- noises_ = [None] * num_vars
240
- noises_[noise_idx] = noises
241
- noises = tuple(noises_)
242
- return lambda *args, **kwargs: noises
213
+ """Initialize a noise function.
214
+
215
+ Parameters
216
+ ----------
217
+ noises: Any
218
+ size: Shape
219
+ The size of the noise.
220
+ num_vars: int
221
+ The number of variables.
222
+ noise_idx: int
223
+ The index of the current noise among all noise variables.
224
+
225
+ Returns
226
+ -------
227
+ noise_func: function, None
228
+ The noise function.
229
+
230
+ See Also
231
+ --------
232
+ variable_, parameter, delay
233
+
234
+ """
235
+ if callable(noises):
236
+ return noises
237
+ elif noises is None:
238
+ return None
239
+ else:
240
+ noises = param(noises, size, allow_none=False)
241
+ if num_vars > 1:
242
+ noises_ = [None] * num_vars
243
+ noises_[noise_idx] = noises
244
+ noises = tuple(noises_)
245
+ return lambda *args, **kwargs: noises