brainstate 0.0.2.post20241010__py2.py3-none-any.whl → 0.1.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (175)
  1. brainstate/__init__.py +31 -11
  2. brainstate/_state.py +760 -316
  3. brainstate/_state_test.py +41 -12
  4. brainstate/_utils.py +31 -4
  5. brainstate/augment/__init__.py +40 -0
  6. brainstate/augment/_autograd.py +608 -0
  7. brainstate/augment/_autograd_test.py +1193 -0
  8. brainstate/augment/_eval_shape.py +102 -0
  9. brainstate/augment/_eval_shape_test.py +40 -0
  10. brainstate/augment/_mapping.py +525 -0
  11. brainstate/augment/_mapping_test.py +210 -0
  12. brainstate/augment/_random.py +99 -0
  13. brainstate/{transform → compile}/__init__.py +25 -13
  14. brainstate/compile/_ad_checkpoint.py +204 -0
  15. brainstate/compile/_ad_checkpoint_test.py +51 -0
  16. brainstate/compile/_conditions.py +259 -0
  17. brainstate/compile/_conditions_test.py +221 -0
  18. brainstate/compile/_error_if.py +94 -0
  19. brainstate/compile/_error_if_test.py +54 -0
  20. brainstate/compile/_jit.py +314 -0
  21. brainstate/compile/_jit_test.py +143 -0
  22. brainstate/compile/_loop_collect_return.py +516 -0
  23. brainstate/compile/_loop_collect_return_test.py +59 -0
  24. brainstate/compile/_loop_no_collection.py +185 -0
  25. brainstate/compile/_loop_no_collection_test.py +51 -0
  26. brainstate/compile/_make_jaxpr.py +756 -0
  27. brainstate/compile/_make_jaxpr_test.py +134 -0
  28. brainstate/compile/_progress_bar.py +111 -0
  29. brainstate/compile/_unvmap.py +159 -0
  30. brainstate/compile/_util.py +147 -0
  31. brainstate/environ.py +408 -381
  32. brainstate/environ_test.py +34 -32
  33. brainstate/{nn/event → event}/__init__.py +6 -6
  34. brainstate/event/_csr.py +308 -0
  35. brainstate/event/_csr_test.py +118 -0
  36. brainstate/event/_fixed_probability.py +271 -0
  37. brainstate/event/_fixed_probability_test.py +128 -0
  38. brainstate/event/_linear.py +219 -0
  39. brainstate/event/_linear_test.py +112 -0
  40. brainstate/{nn/event → event}/_misc.py +7 -7
  41. brainstate/functional/_activations.py +521 -511
  42. brainstate/functional/_activations_test.py +300 -300
  43. brainstate/functional/_normalization.py +43 -43
  44. brainstate/functional/_others.py +15 -15
  45. brainstate/functional/_spikes.py +49 -49
  46. brainstate/graph/__init__.py +33 -0
  47. brainstate/graph/_graph_context.py +443 -0
  48. brainstate/graph/_graph_context_test.py +65 -0
  49. brainstate/graph/_graph_convert.py +246 -0
  50. brainstate/graph/_graph_node.py +300 -0
  51. brainstate/graph/_graph_node_test.py +75 -0
  52. brainstate/graph/_graph_operation.py +1746 -0
  53. brainstate/graph/_graph_operation_test.py +724 -0
  54. brainstate/init/_base.py +28 -10
  55. brainstate/init/_generic.py +175 -172
  56. brainstate/init/_random_inits.py +470 -415
  57. brainstate/init/_random_inits_test.py +150 -0
  58. brainstate/init/_regular_inits.py +66 -69
  59. brainstate/init/_regular_inits_test.py +51 -0
  60. brainstate/mixin.py +236 -244
  61. brainstate/mixin_test.py +44 -46
  62. brainstate/nn/__init__.py +26 -51
  63. brainstate/nn/_collective_ops.py +199 -0
  64. brainstate/nn/_dyn_impl/__init__.py +46 -0
  65. brainstate/nn/_dyn_impl/_dynamics_neuron.py +290 -0
  66. brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +162 -0
  67. brainstate/nn/_dyn_impl/_dynamics_synapse.py +320 -0
  68. brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +132 -0
  69. brainstate/nn/_dyn_impl/_inputs.py +154 -0
  70. brainstate/nn/{_projection/__init__.py → _dyn_impl/_projection_alignpost.py} +6 -13
  71. brainstate/nn/_dyn_impl/_rate_rnns.py +400 -0
  72. brainstate/nn/_dyn_impl/_rate_rnns_test.py +64 -0
  73. brainstate/nn/_dyn_impl/_readout.py +128 -0
  74. brainstate/nn/_dyn_impl/_readout_test.py +54 -0
  75. brainstate/nn/_dynamics/__init__.py +37 -0
  76. brainstate/nn/_dynamics/_dynamics_base.py +631 -0
  77. brainstate/nn/_dynamics/_dynamics_base_test.py +79 -0
  78. brainstate/nn/_dynamics/_projection_base.py +346 -0
  79. brainstate/nn/_dynamics/_state_delay.py +453 -0
  80. brainstate/nn/_dynamics/_synouts.py +161 -0
  81. brainstate/nn/_dynamics/_synouts_test.py +58 -0
  82. brainstate/nn/_elementwise/__init__.py +22 -0
  83. brainstate/nn/_elementwise/_dropout.py +418 -0
  84. brainstate/nn/_elementwise/_dropout_test.py +100 -0
  85. brainstate/nn/_elementwise/_elementwise.py +1122 -0
  86. brainstate/nn/_elementwise/_elementwise_test.py +171 -0
  87. brainstate/nn/_exp_euler.py +97 -0
  88. brainstate/nn/_exp_euler_test.py +36 -0
  89. brainstate/nn/_interaction/__init__.py +32 -0
  90. brainstate/nn/_interaction/_connections.py +726 -0
  91. brainstate/nn/_interaction/_connections_test.py +254 -0
  92. brainstate/nn/_interaction/_embedding.py +59 -0
  93. brainstate/nn/_interaction/_normalizations.py +388 -0
  94. brainstate/nn/_interaction/_normalizations_test.py +75 -0
  95. brainstate/nn/_interaction/_poolings.py +1179 -0
  96. brainstate/nn/_interaction/_poolings_test.py +219 -0
  97. brainstate/nn/_module.py +328 -0
  98. brainstate/nn/_module_test.py +211 -0
  99. brainstate/nn/metrics.py +309 -309
  100. brainstate/optim/__init__.py +14 -2
  101. brainstate/optim/_base.py +66 -0
  102. brainstate/optim/_lr_scheduler.py +363 -400
  103. brainstate/optim/_lr_scheduler_test.py +25 -24
  104. brainstate/optim/_optax_optimizer.py +103 -176
  105. brainstate/optim/_optax_optimizer_test.py +41 -1
  106. brainstate/optim/_sgd_optimizer.py +950 -1025
  107. brainstate/random/_rand_funs.py +3269 -3268
  108. brainstate/random/_rand_funs_test.py +568 -0
  109. brainstate/random/_rand_seed.py +149 -117
  110. brainstate/random/_rand_seed_test.py +50 -0
  111. brainstate/random/_rand_state.py +1356 -1321
  112. brainstate/random/_random_for_unit.py +13 -13
  113. brainstate/surrogate.py +1262 -1243
  114. brainstate/{nn/_projection/_utils.py → transform.py} +1 -2
  115. brainstate/typing.py +157 -130
  116. brainstate/util/__init__.py +52 -0
  117. brainstate/util/_caller.py +100 -0
  118. brainstate/util/_dict.py +734 -0
  119. brainstate/util/_dict_test.py +160 -0
  120. brainstate/util/_error.py +28 -0
  121. brainstate/util/_filter.py +178 -0
  122. brainstate/util/_others.py +497 -0
  123. brainstate/util/_pretty_repr.py +208 -0
  124. brainstate/util/_scaling.py +260 -0
  125. brainstate/util/_struct.py +524 -0
  126. brainstate/util/_tracers.py +75 -0
  127. brainstate/{_visualization.py → util/_visualization.py} +16 -16
  128. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.dist-info}/METADATA +11 -11
  129. brainstate-0.1.0.dist-info/RECORD +135 -0
  130. brainstate/_module.py +0 -1637
  131. brainstate/_module_test.py +0 -207
  132. brainstate/nn/_base.py +0 -251
  133. brainstate/nn/_connections.py +0 -686
  134. brainstate/nn/_dynamics.py +0 -426
  135. brainstate/nn/_elementwise.py +0 -1438
  136. brainstate/nn/_embedding.py +0 -66
  137. brainstate/nn/_misc.py +0 -133
  138. brainstate/nn/_normalizations.py +0 -389
  139. brainstate/nn/_others.py +0 -101
  140. brainstate/nn/_poolings.py +0 -1229
  141. brainstate/nn/_poolings_test.py +0 -231
  142. brainstate/nn/_projection/_align_post.py +0 -546
  143. brainstate/nn/_projection/_align_pre.py +0 -599
  144. brainstate/nn/_projection/_delta.py +0 -241
  145. brainstate/nn/_projection/_vanilla.py +0 -101
  146. brainstate/nn/_rate_rnns.py +0 -410
  147. brainstate/nn/_readout.py +0 -136
  148. brainstate/nn/_synouts.py +0 -166
  149. brainstate/nn/event/csr.py +0 -312
  150. brainstate/nn/event/csr_test.py +0 -118
  151. brainstate/nn/event/fixed_probability.py +0 -276
  152. brainstate/nn/event/fixed_probability_test.py +0 -127
  153. brainstate/nn/event/linear.py +0 -220
  154. brainstate/nn/event/linear_test.py +0 -111
  155. brainstate/random/random_test.py +0 -593
  156. brainstate/transform/_autograd.py +0 -585
  157. brainstate/transform/_autograd_test.py +0 -1181
  158. brainstate/transform/_conditions.py +0 -334
  159. brainstate/transform/_conditions_test.py +0 -220
  160. brainstate/transform/_error_if.py +0 -94
  161. brainstate/transform/_error_if_test.py +0 -55
  162. brainstate/transform/_jit.py +0 -265
  163. brainstate/transform/_jit_test.py +0 -118
  164. brainstate/transform/_loop_collect_return.py +0 -502
  165. brainstate/transform/_loop_no_collection.py +0 -170
  166. brainstate/transform/_make_jaxpr.py +0 -739
  167. brainstate/transform/_make_jaxpr_test.py +0 -131
  168. brainstate/transform/_mapping.py +0 -109
  169. brainstate/transform/_progress_bar.py +0 -111
  170. brainstate/transform/_unvmap.py +0 -143
  171. brainstate/util.py +0 -746
  172. brainstate-0.0.2.post20241010.dist-info/RECORD +0 -87
  173. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.dist-info}/LICENSE +0 -0
  174. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.dist-info}/WHEEL +0 -0
  175. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.dist-info}/top_level.txt +0 -0
@@ -24,7 +24,7 @@ from brainstate._utils import set_module_as
24
24
  from brainstate.typing import ArrayLike
25
25
 
26
26
  __all__ = [
27
- 'weight_standardization',
27
+ 'weight_standardization',
28
28
  ]
29
29
 
30
30
 
@@ -35,49 +35,49 @@ def weight_standardization(
35
35
  gain: Optional[jax.Array] = None,
36
36
  out_axis: int = -1,
37
37
  ) -> Union[jax.Array, u.Quantity]:
38
- """
39
- Scaled Weight Standardization,
40
- see `Micro-Batch Training with Batch-Channel Normalization and Weight Standardization <https://paperswithcode.com/paper/weight-standardization>`_.
38
+ """
39
+ Scaled Weight Standardization,
40
+ see `Micro-Batch Training with Batch-Channel Normalization and Weight Standardization <https://paperswithcode.com/paper/weight-standardization>`_.
41
41
 
42
- Parameters
43
- ----------
44
- w : ArrayLike
45
- The weight tensor.
46
- eps : float
47
- A small value to avoid division by zero.
48
- gain : Array
49
- The gain function, by default None.
50
- out_axis : int
51
- The output axis, by default -1.
42
+ Parameters
43
+ ----------
44
+ w : ArrayLike
45
+ The weight tensor.
46
+ eps : float
47
+ A small value to avoid division by zero.
48
+ gain : Array
49
+ The gain function, by default None.
50
+ out_axis : int
51
+ The output axis, by default -1.
52
52
 
53
- Returns
54
- -------
55
- ArrayLike
56
- The scaled weight tensor.
57
- """
58
- if out_axis < 0:
59
- out_axis = w.ndim + out_axis
60
- fan_in = 1 # get the fan-in of the weight tensor
61
- axes = [] # get the axes of the weight tensor
62
- for i in range(w.ndim):
63
- if i != out_axis:
64
- fan_in *= w.shape[i]
65
- axes.append(i)
66
- # normalize the weight
67
- mean = u.math.mean(w, axis=axes, keepdims=True)
68
- var = u.math.var(w, axis=axes, keepdims=True)
53
+ Returns
54
+ -------
55
+ ArrayLike
56
+ The scaled weight tensor.
57
+ """
58
+ if out_axis < 0:
59
+ out_axis = w.ndim + out_axis
60
+ fan_in = 1 # get the fan-in of the weight tensor
61
+ axes = [] # get the axes of the weight tensor
62
+ for i in range(w.ndim):
63
+ if i != out_axis:
64
+ fan_in *= w.shape[i]
65
+ axes.append(i)
66
+ # normalize the weight
67
+ mean = u.math.mean(w, axis=axes, keepdims=True)
68
+ var = u.math.var(w, axis=axes, keepdims=True)
69
69
 
70
- temp = u.math.maximum(var * fan_in, eps)
71
- if isinstance(temp, u.Quantity):
72
- unit = temp.unit
73
- temp = temp.mantissa
74
- if unit.is_unitless:
75
- scale = jax.lax.rsqrt(temp)
70
+ temp = u.math.maximum(var * fan_in, eps)
71
+ if isinstance(temp, u.Quantity):
72
+ unit = temp.unit
73
+ temp = temp.mantissa
74
+ if unit.is_unitless:
75
+ scale = jax.lax.rsqrt(temp)
76
+ else:
77
+ scale = u.Quantity(jax.lax.rsqrt(temp), unit=1 / unit ** 0.5)
76
78
  else:
77
- scale = u.Quantity(jax.lax.rsqrt(temp), unit=1 / unit ** 0.5)
78
- else:
79
- scale = jax.lax.rsqrt(temp)
80
- if gain is not None:
81
- scale = gain * scale
82
- shift = mean * scale
83
- return w * scale - shift
79
+ scale = jax.lax.rsqrt(temp)
80
+ if gain is not None:
81
+ scale = gain * scale
82
+ shift = mean * scale
83
+ return w * scale - shift
@@ -23,7 +23,7 @@ import jax.numpy as jnp
23
23
  from brainstate.typing import PyTree
24
24
 
25
25
  __all__ = [
26
- 'clip_grad_norm',
26
+ 'clip_grad_norm',
27
27
  ]
28
28
 
29
29
 
@@ -32,17 +32,17 @@ def clip_grad_norm(
32
32
  max_norm: float | jax.Array,
33
33
  norm_type: int | str | None = None
34
34
  ):
35
- """
36
- Clips gradient norm of an iterable of parameters.
37
-
38
- The norm is computed over all gradients together, as if they were
39
- concatenated into a single vector. Gradients are modified in-place.
40
-
41
- Args:
42
- grad (PyTree): an iterable of Tensors or a single Tensor that will have gradients normalized
43
- max_norm (float): max norm of the gradients.
44
- norm_type (int, str, None): type of the used p-norm. Can be ``'inf'`` for infinity norm.
45
- """
46
- norm_fn = partial(jnp.linalg.norm, ord=norm_type)
47
- norm = norm_fn(jnp.asarray(jax.tree.leaves(jax.tree.map(norm_fn, grad))))
48
- return jax.tree.map(lambda x: jnp.where(norm < max_norm, x, x * max_norm / (norm + 1e-6)), grad)
35
+ """
36
+ Clips gradient norm of an iterable of parameters.
37
+
38
+ The norm is computed over all gradients together, as if they were
39
+ concatenated into a single vector. Gradients are modified in-place.
40
+
41
+ Args:
42
+ grad (PyTree): an iterable of Tensors or a single Tensor that will have gradients normalized
43
+ max_norm (float): max norm of the gradients.
44
+ norm_type (int, str, None): type of the used p-norm. Can be ``'inf'`` for infinity norm.
45
+ """
46
+ norm_fn = partial(jnp.linalg.norm, ord=norm_type)
47
+ norm = norm_fn(jnp.asarray(jax.tree.leaves(jax.tree.map(norm_fn, grad))))
48
+ return jax.tree.map(lambda x: jnp.where(norm < max_norm, x, x * max_norm / (norm + 1e-6)), grad)
@@ -16,74 +16,74 @@
16
16
  from __future__ import annotations
17
17
 
18
18
  __all__ = [
19
- 'spike_bitwise_or',
20
- 'spike_bitwise_and',
21
- 'spike_bitwise_iand',
22
- 'spike_bitwise_not',
23
- 'spike_bitwise_xor',
24
- 'spike_bitwise_ixor',
25
- 'spike_bitwise',
19
+ 'spike_bitwise_or',
20
+ 'spike_bitwise_and',
21
+ 'spike_bitwise_iand',
22
+ 'spike_bitwise_not',
23
+ 'spike_bitwise_xor',
24
+ 'spike_bitwise_ixor',
25
+ 'spike_bitwise',
26
26
  ]
27
27
 
28
28
 
29
29
  def spike_bitwise_or(x, y):
30
- """Bitwise OR operation for spike tensors."""
31
- return x + y - x * y
30
+ """Bitwise OR operation for spike tensors."""
31
+ return x + y - x * y
32
32
 
33
33
 
34
34
  def spike_bitwise_and(x, y):
35
- """Bitwise AND operation for spike tensors."""
36
- return x * y
35
+ """Bitwise AND operation for spike tensors."""
36
+ return x * y
37
37
 
38
38
 
39
39
  def spike_bitwise_iand(x, y):
40
- """Bitwise IAND operation for spike tensors."""
41
- return (1 - x) * y
40
+ """Bitwise IAND operation for spike tensors."""
41
+ return (1 - x) * y
42
42
 
43
43
 
44
44
  def spike_bitwise_not(x):
45
- """Bitwise NOT operation for spike tensors."""
46
- return 1 - x
45
+ """Bitwise NOT operation for spike tensors."""
46
+ return 1 - x
47
47
 
48
48
 
49
49
  def spike_bitwise_xor(x, y):
50
- """Bitwise XOR operation for spike tensors."""
51
- return x + y - 2 * x * y
50
+ """Bitwise XOR operation for spike tensors."""
51
+ return x + y - 2 * x * y
52
52
 
53
53
 
54
54
  def spike_bitwise_ixor(x, y):
55
- """Bitwise IXOR operation for spike tensors."""
56
- return x * (1 - y) + (1 - x) * y
55
+ """Bitwise IXOR operation for spike tensors."""
56
+ return x * (1 - y) + (1 - x) * y
57
57
 
58
58
 
59
59
  def spike_bitwise(x, y, op: str):
60
- r"""Bitwise operation for spike tensors.
61
-
62
- .. math::
63
-
64
- \begin{array}{ccc}
65
- \hline \text { Mode } & \text { Expression for } \mathrm{g}(\mathrm{x}, \mathrm{y}) & \text { Code for } \mathrm{g}(\mathrm{x}, \mathrm{y}) \\
66
- \hline \text { ADD } & x+y & x+y \\
67
- \text { AND } & x \cap y & x \cdot y \\
68
- \text { IAND } & (\neg x) \cap y & (1-x) \cdot y \\
69
- \text { OR } & x \cup y & (x+y)-(x \cdot y) \\
70
- \hline
71
- \end{array}
72
-
73
- Args:
74
- x: A spike tensor.
75
- y: A spike tensor.
76
- op: A string indicating the bitwise operation to perform.
77
- """
78
- if op == 'or':
79
- return spike_bitwise_or(x, y)
80
- elif op == 'and':
81
- return spike_bitwise_and(x, y)
82
- elif op == 'iand':
83
- return spike_bitwise_iand(x, y)
84
- elif op == 'xor':
85
- return spike_bitwise_xor(x, y)
86
- elif op == 'ixor':
87
- return spike_bitwise_ixor(x, y)
88
- else:
89
- raise NotImplementedError(f"Unsupported bitwise operation: {op}.")
60
+ r"""Bitwise operation for spike tensors.
61
+
62
+ .. math::
63
+
64
+ \begin{array}{ccc}
65
+ \hline \text { Mode } & \text { Expression for } \mathrm{g}(\mathrm{x}, \mathrm{y}) & \text { Code for } \mathrm{g}(\mathrm{x}, \mathrm{y}) \\
66
+ \hline \text { ADD } & x+y & x+y \\
67
+ \text { AND } & x \cap y & x \cdot y \\
68
+ \text { IAND } & (\neg x) \cap y & (1-x) \cdot y \\
69
+ \text { OR } & x \cup y & (x+y)-(x \cdot y) \\
70
+ \hline
71
+ \end{array}
72
+
73
+ Args:
74
+ x: A spike tensor.
75
+ y: A spike tensor.
76
+ op: A string indicating the bitwise operation to perform.
77
+ """
78
+ if op == 'or':
79
+ return spike_bitwise_or(x, y)
80
+ elif op == 'and':
81
+ return spike_bitwise_and(x, y)
82
+ elif op == 'iand':
83
+ return spike_bitwise_iand(x, y)
84
+ elif op == 'xor':
85
+ return spike_bitwise_xor(x, y)
86
+ elif op == 'ixor':
87
+ return spike_bitwise_ixor(x, y)
88
+ else:
89
+ raise NotImplementedError(f"Unsupported bitwise operation: {op}.")
@@ -0,0 +1,33 @@
1
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+
17
+ from ._graph_context import *
18
+ from ._graph_context import __all__ as _graph_context__all__
19
+ from ._graph_convert import *
20
+ from ._graph_convert import __all__ as _graph_convert__all__
21
+ from ._graph_node import *
22
+ from ._graph_node import __all__ as _graph_node__all__
23
+ from ._graph_operation import *
24
+ from ._graph_operation import __all__ as _graph_operation__all__
25
+
26
+ __all__ = (_graph_context__all__ +
27
+ _graph_convert__all__ +
28
+ _graph_node__all__ +
29
+ _graph_operation__all__)
30
+ del (_graph_context__all__,
31
+ _graph_convert__all__,
32
+ _graph_node__all__,
33
+ _graph_operation__all__)