brainstate-0.0.2.post20241010-py2.py3-none-any.whl → brainstate-0.1.0.post20241122-py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (184)
  1. benchmark/COBA_2005.py +125 -0
  2. benchmark/CUBA_2005.py +149 -0
  3. brainstate/__init__.py +31 -11
  4. brainstate/_state.py +760 -316
  5. brainstate/_state_test.py +41 -12
  6. brainstate/_utils.py +31 -4
  7. brainstate/augment/__init__.py +40 -0
  8. brainstate/augment/_autograd.py +611 -0
  9. brainstate/augment/_autograd_test.py +1193 -0
  10. brainstate/augment/_eval_shape.py +102 -0
  11. brainstate/augment/_eval_shape_test.py +40 -0
  12. brainstate/augment/_mapping.py +525 -0
  13. brainstate/augment/_mapping_test.py +210 -0
  14. brainstate/augment/_random.py +99 -0
  15. brainstate/{transform → compile}/__init__.py +25 -13
  16. brainstate/compile/_ad_checkpoint.py +204 -0
  17. brainstate/compile/_ad_checkpoint_test.py +51 -0
  18. brainstate/compile/_conditions.py +259 -0
  19. brainstate/compile/_conditions_test.py +221 -0
  20. brainstate/compile/_error_if.py +94 -0
  21. brainstate/compile/_error_if_test.py +54 -0
  22. brainstate/compile/_jit.py +314 -0
  23. brainstate/compile/_jit_test.py +143 -0
  24. brainstate/compile/_loop_collect_return.py +516 -0
  25. brainstate/compile/_loop_collect_return_test.py +59 -0
  26. brainstate/compile/_loop_no_collection.py +185 -0
  27. brainstate/compile/_loop_no_collection_test.py +51 -0
  28. brainstate/compile/_make_jaxpr.py +756 -0
  29. brainstate/compile/_make_jaxpr_test.py +134 -0
  30. brainstate/compile/_progress_bar.py +111 -0
  31. brainstate/compile/_unvmap.py +159 -0
  32. brainstate/compile/_util.py +147 -0
  33. brainstate/environ.py +408 -381
  34. brainstate/environ_test.py +34 -32
  35. brainstate/event/__init__.py +27 -0
  36. brainstate/event/_csr.py +316 -0
  37. brainstate/event/_csr_benchmark.py +14 -0
  38. brainstate/event/_csr_test.py +118 -0
  39. brainstate/event/_fixed_probability.py +708 -0
  40. brainstate/event/_fixed_probability_benchmark.py +128 -0
  41. brainstate/event/_fixed_probability_test.py +131 -0
  42. brainstate/event/_linear.py +359 -0
  43. brainstate/event/_linear_benckmark.py +82 -0
  44. brainstate/event/_linear_test.py +117 -0
  45. brainstate/{nn/event → event}/_misc.py +7 -7
  46. brainstate/event/_xla_custom_op.py +312 -0
  47. brainstate/event/_xla_custom_op_test.py +55 -0
  48. brainstate/functional/_activations.py +521 -511
  49. brainstate/functional/_activations_test.py +300 -300
  50. brainstate/functional/_normalization.py +43 -43
  51. brainstate/functional/_others.py +15 -15
  52. brainstate/functional/_spikes.py +49 -49
  53. brainstate/graph/__init__.py +33 -0
  54. brainstate/graph/_graph_context.py +443 -0
  55. brainstate/graph/_graph_context_test.py +65 -0
  56. brainstate/graph/_graph_convert.py +246 -0
  57. brainstate/graph/_graph_node.py +300 -0
  58. brainstate/graph/_graph_node_test.py +75 -0
  59. brainstate/graph/_graph_operation.py +1746 -0
  60. brainstate/graph/_graph_operation_test.py +724 -0
  61. brainstate/init/_base.py +28 -10
  62. brainstate/init/_generic.py +175 -172
  63. brainstate/init/_random_inits.py +470 -415
  64. brainstate/init/_random_inits_test.py +150 -0
  65. brainstate/init/_regular_inits.py +66 -69
  66. brainstate/init/_regular_inits_test.py +51 -0
  67. brainstate/mixin.py +236 -244
  68. brainstate/mixin_test.py +44 -46
  69. brainstate/nn/__init__.py +26 -51
  70. brainstate/nn/_collective_ops.py +199 -0
  71. brainstate/nn/_dyn_impl/__init__.py +46 -0
  72. brainstate/nn/_dyn_impl/_dynamics_neuron.py +290 -0
  73. brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +162 -0
  74. brainstate/nn/_dyn_impl/_dynamics_synapse.py +315 -0
  75. brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +132 -0
  76. brainstate/nn/_dyn_impl/_inputs.py +154 -0
  77. brainstate/nn/{event/__init__.py → _dyn_impl/_projection_alignpost.py} +8 -8
  78. brainstate/nn/_dyn_impl/_rate_rnns.py +400 -0
  79. brainstate/nn/_dyn_impl/_rate_rnns_test.py +64 -0
  80. brainstate/nn/_dyn_impl/_readout.py +128 -0
  81. brainstate/nn/_dyn_impl/_readout_test.py +54 -0
  82. brainstate/nn/_dynamics/__init__.py +37 -0
  83. brainstate/nn/_dynamics/_dynamics_base.py +631 -0
  84. brainstate/nn/_dynamics/_dynamics_base_test.py +79 -0
  85. brainstate/nn/_dynamics/_projection_base.py +346 -0
  86. brainstate/nn/_dynamics/_state_delay.py +453 -0
  87. brainstate/nn/_dynamics/_synouts.py +161 -0
  88. brainstate/nn/_dynamics/_synouts_test.py +58 -0
  89. brainstate/nn/_elementwise/__init__.py +22 -0
  90. brainstate/nn/_elementwise/_dropout.py +418 -0
  91. brainstate/nn/_elementwise/_dropout_test.py +100 -0
  92. brainstate/nn/_elementwise/_elementwise.py +1122 -0
  93. brainstate/nn/_elementwise/_elementwise_test.py +171 -0
  94. brainstate/nn/_exp_euler.py +97 -0
  95. brainstate/nn/_exp_euler_test.py +36 -0
  96. brainstate/nn/_interaction/__init__.py +41 -0
  97. brainstate/nn/_interaction/_conv.py +499 -0
  98. brainstate/nn/_interaction/_conv_test.py +239 -0
  99. brainstate/nn/_interaction/_embedding.py +59 -0
  100. brainstate/nn/_interaction/_linear.py +582 -0
  101. brainstate/nn/_interaction/_linear_test.py +42 -0
  102. brainstate/nn/_interaction/_normalizations.py +388 -0
  103. brainstate/nn/_interaction/_normalizations_test.py +75 -0
  104. brainstate/nn/_interaction/_poolings.py +1179 -0
  105. brainstate/nn/_interaction/_poolings_test.py +219 -0
  106. brainstate/nn/_module.py +328 -0
  107. brainstate/nn/_module_test.py +211 -0
  108. brainstate/nn/metrics.py +309 -309
  109. brainstate/optim/__init__.py +14 -2
  110. brainstate/optim/_base.py +66 -0
  111. brainstate/optim/_lr_scheduler.py +363 -400
  112. brainstate/optim/_lr_scheduler_test.py +25 -24
  113. brainstate/optim/_optax_optimizer.py +121 -176
  114. brainstate/optim/_optax_optimizer_test.py +41 -1
  115. brainstate/optim/_sgd_optimizer.py +950 -1025
  116. brainstate/random/_rand_funs.py +3269 -3268
  117. brainstate/random/_rand_funs_test.py +568 -0
  118. brainstate/random/_rand_seed.py +149 -117
  119. brainstate/random/_rand_seed_test.py +50 -0
  120. brainstate/random/_rand_state.py +1356 -1321
  121. brainstate/random/_random_for_unit.py +13 -13
  122. brainstate/surrogate.py +1262 -1243
  123. brainstate/{nn/_projection/_utils.py → transform.py} +1 -2
  124. brainstate/typing.py +157 -130
  125. brainstate/util/__init__.py +52 -0
  126. brainstate/util/_caller.py +100 -0
  127. brainstate/util/_dict.py +734 -0
  128. brainstate/util/_dict_test.py +160 -0
  129. brainstate/{nn/_projection/__init__.py → util/_error.py} +9 -13
  130. brainstate/util/_filter.py +178 -0
  131. brainstate/util/_others.py +497 -0
  132. brainstate/util/_pretty_repr.py +208 -0
  133. brainstate/util/_scaling.py +260 -0
  134. brainstate/util/_struct.py +524 -0
  135. brainstate/util/_tracers.py +75 -0
  136. brainstate/{_visualization.py → util/_visualization.py} +16 -16
  137. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/METADATA +11 -11
  138. brainstate-0.1.0.post20241122.dist-info/RECORD +144 -0
  139. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/top_level.txt +1 -0
  140. brainstate/_module.py +0 -1637
  141. brainstate/_module_test.py +0 -207
  142. brainstate/nn/_base.py +0 -251
  143. brainstate/nn/_connections.py +0 -686
  144. brainstate/nn/_dynamics.py +0 -426
  145. brainstate/nn/_elementwise.py +0 -1438
  146. brainstate/nn/_embedding.py +0 -66
  147. brainstate/nn/_misc.py +0 -133
  148. brainstate/nn/_normalizations.py +0 -389
  149. brainstate/nn/_others.py +0 -101
  150. brainstate/nn/_poolings.py +0 -1229
  151. brainstate/nn/_poolings_test.py +0 -231
  152. brainstate/nn/_projection/_align_post.py +0 -546
  153. brainstate/nn/_projection/_align_pre.py +0 -599
  154. brainstate/nn/_projection/_delta.py +0 -241
  155. brainstate/nn/_projection/_vanilla.py +0 -101
  156. brainstate/nn/_rate_rnns.py +0 -410
  157. brainstate/nn/_readout.py +0 -136
  158. brainstate/nn/_synouts.py +0 -166
  159. brainstate/nn/event/csr.py +0 -312
  160. brainstate/nn/event/csr_test.py +0 -118
  161. brainstate/nn/event/fixed_probability.py +0 -276
  162. brainstate/nn/event/fixed_probability_test.py +0 -127
  163. brainstate/nn/event/linear.py +0 -220
  164. brainstate/nn/event/linear_test.py +0 -111
  165. brainstate/random/random_test.py +0 -593
  166. brainstate/transform/_autograd.py +0 -585
  167. brainstate/transform/_autograd_test.py +0 -1181
  168. brainstate/transform/_conditions.py +0 -334
  169. brainstate/transform/_conditions_test.py +0 -220
  170. brainstate/transform/_error_if.py +0 -94
  171. brainstate/transform/_error_if_test.py +0 -55
  172. brainstate/transform/_jit.py +0 -265
  173. brainstate/transform/_jit_test.py +0 -118
  174. brainstate/transform/_loop_collect_return.py +0 -502
  175. brainstate/transform/_loop_no_collection.py +0 -170
  176. brainstate/transform/_make_jaxpr.py +0 -739
  177. brainstate/transform/_make_jaxpr_test.py +0 -131
  178. brainstate/transform/_mapping.py +0 -109
  179. brainstate/transform/_progress_bar.py +0 -111
  180. brainstate/transform/_unvmap.py +0 -143
  181. brainstate/util.py +0 -746
  182. brainstate-0.0.2.post20241010.dist-info/RECORD +0 -87
  183. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/LICENSE +0 -0
  184. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/WHEEL +0 -0
brainstate/functional/_normalization.py
@@ -24,7 +24,7 @@ from brainstate._utils import set_module_as
 from brainstate.typing import ArrayLike
 
 __all__ = [
-  'weight_standardization',
+    'weight_standardization',
 ]
 
 
@@ -35,49 +35,49 @@ def weight_standardization(
     gain: Optional[jax.Array] = None,
     out_axis: int = -1,
 ) -> Union[jax.Array, u.Quantity]:
-  """
-  Scaled Weight Standardization,
-  see `Micro-Batch Training with Batch-Channel Normalization and Weight Standardization <https://paperswithcode.com/paper/weight-standardization>`_.
+    """
+    Scaled Weight Standardization,
+    see `Micro-Batch Training with Batch-Channel Normalization and Weight Standardization <https://paperswithcode.com/paper/weight-standardization>`_.
 
-  Parameters
-  ----------
-  w : ArrayLike
-    The weight tensor.
-  eps : float
-    A small value to avoid division by zero.
-  gain : Array
-    The gain function, by default None.
-  out_axis : int
-    The output axis, by default -1.
+    Parameters
+    ----------
+    w : ArrayLike
+        The weight tensor.
+    eps : float
+        A small value to avoid division by zero.
+    gain : Array
+        The gain function, by default None.
+    out_axis : int
+        The output axis, by default -1.
 
-  Returns
-  -------
-  ArrayLike
-    The scaled weight tensor.
-  """
-  if out_axis < 0:
-    out_axis = w.ndim + out_axis
-  fan_in = 1  # get the fan-in of the weight tensor
-  axes = []  # get the axes of the weight tensor
-  for i in range(w.ndim):
-    if i != out_axis:
-      fan_in *= w.shape[i]
-      axes.append(i)
-  # normalize the weight
-  mean = u.math.mean(w, axis=axes, keepdims=True)
-  var = u.math.var(w, axis=axes, keepdims=True)
+    Returns
+    -------
+    ArrayLike
+        The scaled weight tensor.
+    """
+    if out_axis < 0:
+        out_axis = w.ndim + out_axis
+    fan_in = 1  # get the fan-in of the weight tensor
+    axes = []  # get the axes of the weight tensor
+    for i in range(w.ndim):
+        if i != out_axis:
+            fan_in *= w.shape[i]
+            axes.append(i)
+    # normalize the weight
+    mean = u.math.mean(w, axis=axes, keepdims=True)
+    var = u.math.var(w, axis=axes, keepdims=True)
 
-  temp = u.math.maximum(var * fan_in, eps)
-  if isinstance(temp, u.Quantity):
-    unit = temp.unit
-    temp = temp.mantissa
-    if unit.is_unitless:
-      scale = jax.lax.rsqrt(temp)
+    temp = u.math.maximum(var * fan_in, eps)
+    if isinstance(temp, u.Quantity):
+        unit = temp.unit
+        temp = temp.mantissa
+        if unit.is_unitless:
+            scale = jax.lax.rsqrt(temp)
+        else:
+            scale = u.Quantity(jax.lax.rsqrt(temp), unit=1 / unit ** 0.5)
     else:
-      scale = u.Quantity(jax.lax.rsqrt(temp), unit=1 / unit ** 0.5)
-  else:
-    scale = jax.lax.rsqrt(temp)
-  if gain is not None:
-    scale = gain * scale
-  shift = mean * scale
-  return w * scale - shift
+        scale = jax.lax.rsqrt(temp)
+    if gain is not None:
+        scale = gain * scale
+    shift = mean * scale
+    return w * scale - shift
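The weight_standardization hunk above reformats the function without changing its logic: weights are standardized over every axis except out_axis and scaled by the fan-in. As a quick reference, here is a minimal standalone sketch of the same computation for plain jax.Array inputs (the helper name, demo shapes, and the eps default below are illustrative; the packaged version additionally handles brainunit Quantity weights):

import math
import jax
import jax.numpy as jnp

def weight_standardization_sketch(w, eps=1e-4, gain=None, out_axis=-1):
    # Standardize w over every axis except out_axis, scaled by the fan-in,
    # mirroring the diffed implementation for unitless arrays.
    if out_axis < 0:
        out_axis = w.ndim + out_axis
    axes = [i for i in range(w.ndim) if i != out_axis]
    fan_in = math.prod(w.shape[i] for i in axes)
    mean = jnp.mean(w, axis=axes, keepdims=True)
    var = jnp.var(w, axis=axes, keepdims=True)
    scale = jax.lax.rsqrt(jnp.maximum(var * fan_in, eps))
    if gain is not None:
        scale = gain * scale
    return w * scale - mean * scale

# Example: an HWIO convolution kernel with output channels on the last axis.
w = jax.random.normal(jax.random.PRNGKey(0), (3, 3, 16, 32))
ws = weight_standardization_sketch(w)
# Each output channel of ws now has mean ~0 and variance ~1 / fan_in (1 / 144 here).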
brainstate/functional/_others.py
@@ -23,7 +23,7 @@ import jax.numpy as jnp
 from brainstate.typing import PyTree
 
 __all__ = [
-  'clip_grad_norm',
+    'clip_grad_norm',
 ]
 
 
@@ -32,17 +32,17 @@ def clip_grad_norm(
     max_norm: float | jax.Array,
     norm_type: int | str | None = None
 ):
-  """
-  Clips gradient norm of an iterable of parameters.
-
-  The norm is computed over all gradients together, as if they were
-  concatenated into a single vector. Gradients are modified in-place.
-
-  Args:
-    grad (PyTree): an iterable of Tensors or a single Tensor that will have gradients normalized
-    max_norm (float): max norm of the gradients.
-    norm_type (int, str, None): type of the used p-norm. Can be ``'inf'`` for infinity norm.
-  """
-  norm_fn = partial(jnp.linalg.norm, ord=norm_type)
-  norm = norm_fn(jnp.asarray(jax.tree.leaves(jax.tree.map(norm_fn, grad))))
-  return jax.tree.map(lambda x: jnp.where(norm < max_norm, x, x * max_norm / (norm + 1e-6)), grad)
+    """
+    Clips gradient norm of an iterable of parameters.
+
+    The norm is computed over all gradients together, as if they were
+    concatenated into a single vector. Gradients are modified in-place.
+
+    Args:
+      grad (PyTree): an iterable of Tensors or a single Tensor that will have gradients normalized
+      max_norm (float): max norm of the gradients.
+      norm_type (int, str, None): type of the used p-norm. Can be ``'inf'`` for infinity norm.
+    """
+    norm_fn = partial(jnp.linalg.norm, ord=norm_type)
+    norm = norm_fn(jnp.asarray(jax.tree.leaves(jax.tree.map(norm_fn, grad))))
+    return jax.tree.map(lambda x: jnp.where(norm < max_norm, x, x * max_norm / (norm + 1e-6)), grad)
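For context, clip_grad_norm computes one global norm over all leaves of the gradient pytree and rescales every leaf whenever that norm exceeds max_norm. A self-contained sketch of the same logic, runnable as shown (the wrapper name and demo pytree are illustrative, not part of the package API):

from functools import partial
import jax
import jax.numpy as jnp

def clip_grad_norm_sketch(grad, max_norm, norm_type=None):
    # Per-leaf norms, then one global norm over the whole pytree.
    norm_fn = partial(jnp.linalg.norm, ord=norm_type)
    norm = norm_fn(jnp.asarray(jax.tree.leaves(jax.tree.map(norm_fn, grad))))
    # Leave gradients untouched below the threshold, otherwise rescale to max_norm.
    return jax.tree.map(
        lambda x: jnp.where(norm < max_norm, x, x * max_norm / (norm + 1e-6)),
        grad,
    )

grads = {'w': jnp.full((4, 4), 10.0), 'b': jnp.full((4,), 10.0)}
clipped = clip_grad_norm_sketch(grads, max_norm=1.0)
# The global norm of `clipped` is now ~1.0.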
brainstate/functional/_spikes.py
@@ -16,74 +16,74 @@
 from __future__ import annotations
 
 __all__ = [
-  'spike_bitwise_or',
-  'spike_bitwise_and',
-  'spike_bitwise_iand',
-  'spike_bitwise_not',
-  'spike_bitwise_xor',
-  'spike_bitwise_ixor',
-  'spike_bitwise',
+    'spike_bitwise_or',
+    'spike_bitwise_and',
+    'spike_bitwise_iand',
+    'spike_bitwise_not',
+    'spike_bitwise_xor',
+    'spike_bitwise_ixor',
+    'spike_bitwise',
 ]
 
 
 def spike_bitwise_or(x, y):
-  """Bitwise OR operation for spike tensors."""
-  return x + y - x * y
+    """Bitwise OR operation for spike tensors."""
+    return x + y - x * y
 
 
 def spike_bitwise_and(x, y):
-  """Bitwise AND operation for spike tensors."""
-  return x * y
+    """Bitwise AND operation for spike tensors."""
+    return x * y
 
 
 def spike_bitwise_iand(x, y):
-  """Bitwise IAND operation for spike tensors."""
-  return (1 - x) * y
+    """Bitwise IAND operation for spike tensors."""
+    return (1 - x) * y
 
 
 def spike_bitwise_not(x):
-  """Bitwise NOT operation for spike tensors."""
-  return 1 - x
+    """Bitwise NOT operation for spike tensors."""
+    return 1 - x
 
 
 def spike_bitwise_xor(x, y):
-  """Bitwise XOR operation for spike tensors."""
-  return x + y - 2 * x * y
+    """Bitwise XOR operation for spike tensors."""
+    return x + y - 2 * x * y
 
 
 def spike_bitwise_ixor(x, y):
-  """Bitwise IXOR operation for spike tensors."""
-  return x * (1 - y) + (1 - x) * y
+    """Bitwise IXOR operation for spike tensors."""
+    return x * (1 - y) + (1 - x) * y
 
 
 def spike_bitwise(x, y, op: str):
-  r"""Bitwise operation for spike tensors.
-
-  .. math::
-
-     \begin{array}{ccc}
-     \hline \text { Mode } & \text { Expression for } \mathrm{g}(\mathrm{x}, \mathrm{y}) & \text { Code for } \mathrm{g}(\mathrm{x}, \mathrm{y}) \\
-     \hline \text { ADD } & x+y & x+y \\
-     \text { AND } & x \cap y & x \cdot y \\
-     \text { IAND } & (\neg x) \cap y & (1-x) \cdot y \\
-     \text { OR } & x \cup y & (x+y)-(x \cdot y) \\
-     \hline
-     \end{array}
-
-  Args:
-    x: A spike tensor.
-    y: A spike tensor.
-    op: A string indicating the bitwise operation to perform.
-  """
-  if op == 'or':
-    return spike_bitwise_or(x, y)
-  elif op == 'and':
-    return spike_bitwise_and(x, y)
-  elif op == 'iand':
-    return spike_bitwise_iand(x, y)
-  elif op == 'xor':
-    return spike_bitwise_xor(x, y)
-  elif op == 'ixor':
-    return spike_bitwise_ixor(x, y)
-  else:
-    raise NotImplementedError(f"Unsupported bitwise operation: {op}.")
+    r"""Bitwise operation for spike tensors.
+
+    .. math::
+
+       \begin{array}{ccc}
+       \hline \text { Mode } & \text { Expression for } \mathrm{g}(\mathrm{x}, \mathrm{y}) & \text { Code for } \mathrm{g}(\mathrm{x}, \mathrm{y}) \\
+       \hline \text { ADD } & x+y & x+y \\
+       \text { AND } & x \cap y & x \cdot y \\
+       \text { IAND } & (\neg x) \cap y & (1-x) \cdot y \\
+       \text { OR } & x \cup y & (x+y)-(x \cdot y) \\
+       \hline
+       \end{array}
+
+    Args:
+      x: A spike tensor.
+      y: A spike tensor.
+      op: A string indicating the bitwise operation to perform.
+    """
+    if op == 'or':
+        return spike_bitwise_or(x, y)
+    elif op == 'and':
+        return spike_bitwise_and(x, y)
+    elif op == 'iand':
+        return spike_bitwise_iand(x, y)
+    elif op == 'xor':
+        return spike_bitwise_xor(x, y)
+    elif op == 'ixor':
+        return spike_bitwise_ixor(x, y)
+    else:
+        raise NotImplementedError(f"Unsupported bitwise operation: {op}.")
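These helpers express Boolean logic on {0, 1}-valued spike tensors with ordinary arithmetic, so they apply elementwise to whole arrays. The identities from the hunk above are easy to verify on concrete inputs (plain JAX, no brainstate import needed):

import jax.numpy as jnp

x = jnp.array([0, 0, 1, 1])
y = jnp.array([0, 1, 0, 1])

assert jnp.array_equal(x + y - x * y, jnp.array([0, 1, 1, 1]))              # OR
assert jnp.array_equal(x * y, jnp.array([0, 0, 0, 1]))                      # AND
assert jnp.array_equal((1 - x) * y, jnp.array([0, 1, 0, 0]))                # IAND: (not x) and y
assert jnp.array_equal(1 - x, jnp.array([1, 1, 0, 0]))                      # NOT
assert jnp.array_equal(x + y - 2 * x * y, jnp.array([0, 1, 1, 0]))          # XOR
assert jnp.array_equal(x * (1 - y) + (1 - x) * y, jnp.array([0, 1, 1, 0]))  # IXOR (same truth table as XOR)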
brainstate/graph/__init__.py (new file)
@@ -0,0 +1,33 @@
+# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+
+from ._graph_context import *
+from ._graph_context import __all__ as _graph_context__all__
+from ._graph_convert import *
+from ._graph_convert import __all__ as _graph_convert__all__
+from ._graph_node import *
+from ._graph_node import __all__ as _graph_node__all__
+from ._graph_operation import *
+from ._graph_operation import __all__ as _graph_operation__all__
+
+__all__ = (_graph_context__all__ +
+           _graph_convert__all__ +
+           _graph_node__all__ +
+           _graph_operation__all__)
+del (_graph_context__all__,
+     _graph_convert__all__,
+     _graph_node__all__,
+     _graph_operation__all__)
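The new brainstate/graph/__init__.py follows a star-import-plus-__all__ aggregation idiom: each private submodule declares its public names, and the package namespace re-exports exactly their union. A minimal sketch of the same idiom for a made-up two-module package (every name below is hypothetical, for illustration only):

# mypkg/_alpha.py  (hypothetical submodule)
__all__ = ['alpha_fn']

def alpha_fn():
    return 'alpha'


# mypkg/_beta.py  (hypothetical submodule)
__all__ = ['beta_fn']

def beta_fn():
    return 'beta'


# mypkg/__init__.py  (same aggregation pattern as the hunk above)
from ._alpha import *
from ._alpha import __all__ as _alpha_all
from ._beta import *
from ._beta import __all__ as _beta_all

__all__ = _alpha_all + _beta_all  # public API = union of the submodules' __all__
del _alpha_all, _beta_all         # keep the helper names out of the package namespace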