brainstate 0.0.2.post20241010__py2.py3-none-any.whl → 0.1.0.post20241122__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (184)
  1. benchmark/COBA_2005.py +125 -0
  2. benchmark/CUBA_2005.py +149 -0
  3. brainstate/__init__.py +31 -11
  4. brainstate/_state.py +760 -316
  5. brainstate/_state_test.py +41 -12
  6. brainstate/_utils.py +31 -4
  7. brainstate/augment/__init__.py +40 -0
  8. brainstate/augment/_autograd.py +611 -0
  9. brainstate/augment/_autograd_test.py +1193 -0
  10. brainstate/augment/_eval_shape.py +102 -0
  11. brainstate/augment/_eval_shape_test.py +40 -0
  12. brainstate/augment/_mapping.py +525 -0
  13. brainstate/augment/_mapping_test.py +210 -0
  14. brainstate/augment/_random.py +99 -0
  15. brainstate/{transform → compile}/__init__.py +25 -13
  16. brainstate/compile/_ad_checkpoint.py +204 -0
  17. brainstate/compile/_ad_checkpoint_test.py +51 -0
  18. brainstate/compile/_conditions.py +259 -0
  19. brainstate/compile/_conditions_test.py +221 -0
  20. brainstate/compile/_error_if.py +94 -0
  21. brainstate/compile/_error_if_test.py +54 -0
  22. brainstate/compile/_jit.py +314 -0
  23. brainstate/compile/_jit_test.py +143 -0
  24. brainstate/compile/_loop_collect_return.py +516 -0
  25. brainstate/compile/_loop_collect_return_test.py +59 -0
  26. brainstate/compile/_loop_no_collection.py +185 -0
  27. brainstate/compile/_loop_no_collection_test.py +51 -0
  28. brainstate/compile/_make_jaxpr.py +756 -0
  29. brainstate/compile/_make_jaxpr_test.py +134 -0
  30. brainstate/compile/_progress_bar.py +111 -0
  31. brainstate/compile/_unvmap.py +159 -0
  32. brainstate/compile/_util.py +147 -0
  33. brainstate/environ.py +408 -381
  34. brainstate/environ_test.py +34 -32
  35. brainstate/event/__init__.py +27 -0
  36. brainstate/event/_csr.py +316 -0
  37. brainstate/event/_csr_benchmark.py +14 -0
  38. brainstate/event/_csr_test.py +118 -0
  39. brainstate/event/_fixed_probability.py +708 -0
  40. brainstate/event/_fixed_probability_benchmark.py +128 -0
  41. brainstate/event/_fixed_probability_test.py +131 -0
  42. brainstate/event/_linear.py +359 -0
  43. brainstate/event/_linear_benckmark.py +82 -0
  44. brainstate/event/_linear_test.py +117 -0
  45. brainstate/{nn/event → event}/_misc.py +7 -7
  46. brainstate/event/_xla_custom_op.py +312 -0
  47. brainstate/event/_xla_custom_op_test.py +55 -0
  48. brainstate/functional/_activations.py +521 -511
  49. brainstate/functional/_activations_test.py +300 -300
  50. brainstate/functional/_normalization.py +43 -43
  51. brainstate/functional/_others.py +15 -15
  52. brainstate/functional/_spikes.py +49 -49
  53. brainstate/graph/__init__.py +33 -0
  54. brainstate/graph/_graph_context.py +443 -0
  55. brainstate/graph/_graph_context_test.py +65 -0
  56. brainstate/graph/_graph_convert.py +246 -0
  57. brainstate/graph/_graph_node.py +300 -0
  58. brainstate/graph/_graph_node_test.py +75 -0
  59. brainstate/graph/_graph_operation.py +1746 -0
  60. brainstate/graph/_graph_operation_test.py +724 -0
  61. brainstate/init/_base.py +28 -10
  62. brainstate/init/_generic.py +175 -172
  63. brainstate/init/_random_inits.py +470 -415
  64. brainstate/init/_random_inits_test.py +150 -0
  65. brainstate/init/_regular_inits.py +66 -69
  66. brainstate/init/_regular_inits_test.py +51 -0
  67. brainstate/mixin.py +236 -244
  68. brainstate/mixin_test.py +44 -46
  69. brainstate/nn/__init__.py +26 -51
  70. brainstate/nn/_collective_ops.py +199 -0
  71. brainstate/nn/_dyn_impl/__init__.py +46 -0
  72. brainstate/nn/_dyn_impl/_dynamics_neuron.py +290 -0
  73. brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +162 -0
  74. brainstate/nn/_dyn_impl/_dynamics_synapse.py +315 -0
  75. brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +132 -0
  76. brainstate/nn/_dyn_impl/_inputs.py +154 -0
  77. brainstate/nn/{event/__init__.py → _dyn_impl/_projection_alignpost.py} +8 -8
  78. brainstate/nn/_dyn_impl/_rate_rnns.py +400 -0
  79. brainstate/nn/_dyn_impl/_rate_rnns_test.py +64 -0
  80. brainstate/nn/_dyn_impl/_readout.py +128 -0
  81. brainstate/nn/_dyn_impl/_readout_test.py +54 -0
  82. brainstate/nn/_dynamics/__init__.py +37 -0
  83. brainstate/nn/_dynamics/_dynamics_base.py +631 -0
  84. brainstate/nn/_dynamics/_dynamics_base_test.py +79 -0
  85. brainstate/nn/_dynamics/_projection_base.py +346 -0
  86. brainstate/nn/_dynamics/_state_delay.py +453 -0
  87. brainstate/nn/_dynamics/_synouts.py +161 -0
  88. brainstate/nn/_dynamics/_synouts_test.py +58 -0
  89. brainstate/nn/_elementwise/__init__.py +22 -0
  90. brainstate/nn/_elementwise/_dropout.py +418 -0
  91. brainstate/nn/_elementwise/_dropout_test.py +100 -0
  92. brainstate/nn/_elementwise/_elementwise.py +1122 -0
  93. brainstate/nn/_elementwise/_elementwise_test.py +171 -0
  94. brainstate/nn/_exp_euler.py +97 -0
  95. brainstate/nn/_exp_euler_test.py +36 -0
  96. brainstate/nn/_interaction/__init__.py +41 -0
  97. brainstate/nn/_interaction/_conv.py +499 -0
  98. brainstate/nn/_interaction/_conv_test.py +239 -0
  99. brainstate/nn/_interaction/_embedding.py +59 -0
  100. brainstate/nn/_interaction/_linear.py +582 -0
  101. brainstate/nn/_interaction/_linear_test.py +42 -0
  102. brainstate/nn/_interaction/_normalizations.py +388 -0
  103. brainstate/nn/_interaction/_normalizations_test.py +75 -0
  104. brainstate/nn/_interaction/_poolings.py +1179 -0
  105. brainstate/nn/_interaction/_poolings_test.py +219 -0
  106. brainstate/nn/_module.py +328 -0
  107. brainstate/nn/_module_test.py +211 -0
  108. brainstate/nn/metrics.py +309 -309
  109. brainstate/optim/__init__.py +14 -2
  110. brainstate/optim/_base.py +66 -0
  111. brainstate/optim/_lr_scheduler.py +363 -400
  112. brainstate/optim/_lr_scheduler_test.py +25 -24
  113. brainstate/optim/_optax_optimizer.py +121 -176
  114. brainstate/optim/_optax_optimizer_test.py +41 -1
  115. brainstate/optim/_sgd_optimizer.py +950 -1025
  116. brainstate/random/_rand_funs.py +3269 -3268
  117. brainstate/random/_rand_funs_test.py +568 -0
  118. brainstate/random/_rand_seed.py +149 -117
  119. brainstate/random/_rand_seed_test.py +50 -0
  120. brainstate/random/_rand_state.py +1356 -1321
  121. brainstate/random/_random_for_unit.py +13 -13
  122. brainstate/surrogate.py +1262 -1243
  123. brainstate/{nn/_projection/_utils.py → transform.py} +1 -2
  124. brainstate/typing.py +157 -130
  125. brainstate/util/__init__.py +52 -0
  126. brainstate/util/_caller.py +100 -0
  127. brainstate/util/_dict.py +734 -0
  128. brainstate/util/_dict_test.py +160 -0
  129. brainstate/{nn/_projection/__init__.py → util/_error.py} +9 -13
  130. brainstate/util/_filter.py +178 -0
  131. brainstate/util/_others.py +497 -0
  132. brainstate/util/_pretty_repr.py +208 -0
  133. brainstate/util/_scaling.py +260 -0
  134. brainstate/util/_struct.py +524 -0
  135. brainstate/util/_tracers.py +75 -0
  136. brainstate/{_visualization.py → util/_visualization.py} +16 -16
  137. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/METADATA +11 -11
  138. brainstate-0.1.0.post20241122.dist-info/RECORD +144 -0
  139. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/top_level.txt +1 -0
  140. brainstate/_module.py +0 -1637
  141. brainstate/_module_test.py +0 -207
  142. brainstate/nn/_base.py +0 -251
  143. brainstate/nn/_connections.py +0 -686
  144. brainstate/nn/_dynamics.py +0 -426
  145. brainstate/nn/_elementwise.py +0 -1438
  146. brainstate/nn/_embedding.py +0 -66
  147. brainstate/nn/_misc.py +0 -133
  148. brainstate/nn/_normalizations.py +0 -389
  149. brainstate/nn/_others.py +0 -101
  150. brainstate/nn/_poolings.py +0 -1229
  151. brainstate/nn/_poolings_test.py +0 -231
  152. brainstate/nn/_projection/_align_post.py +0 -546
  153. brainstate/nn/_projection/_align_pre.py +0 -599
  154. brainstate/nn/_projection/_delta.py +0 -241
  155. brainstate/nn/_projection/_vanilla.py +0 -101
  156. brainstate/nn/_rate_rnns.py +0 -410
  157. brainstate/nn/_readout.py +0 -136
  158. brainstate/nn/_synouts.py +0 -166
  159. brainstate/nn/event/csr.py +0 -312
  160. brainstate/nn/event/csr_test.py +0 -118
  161. brainstate/nn/event/fixed_probability.py +0 -276
  162. brainstate/nn/event/fixed_probability_test.py +0 -127
  163. brainstate/nn/event/linear.py +0 -220
  164. brainstate/nn/event/linear_test.py +0 -111
  165. brainstate/random/random_test.py +0 -593
  166. brainstate/transform/_autograd.py +0 -585
  167. brainstate/transform/_autograd_test.py +0 -1181
  168. brainstate/transform/_conditions.py +0 -334
  169. brainstate/transform/_conditions_test.py +0 -220
  170. brainstate/transform/_error_if.py +0 -94
  171. brainstate/transform/_error_if_test.py +0 -55
  172. brainstate/transform/_jit.py +0 -265
  173. brainstate/transform/_jit_test.py +0 -118
  174. brainstate/transform/_loop_collect_return.py +0 -502
  175. brainstate/transform/_loop_no_collection.py +0 -170
  176. brainstate/transform/_make_jaxpr.py +0 -739
  177. brainstate/transform/_make_jaxpr_test.py +0 -131
  178. brainstate/transform/_mapping.py +0 -109
  179. brainstate/transform/_progress_bar.py +0 -111
  180. brainstate/transform/_unvmap.py +0 -143
  181. brainstate/util.py +0 -746
  182. brainstate-0.0.2.post20241010.dist-info/RECORD +0 -87
  183. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/LICENSE +0 -0
  184. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/WHEEL +0 -0
@@ -0,0 +1,239 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
from __future__ import annotations

import unittest

import jax.numpy as jnp
import pytest
from absl.testing import absltest
from absl.testing import parameterized

import brainstate as bst
11
+
12
+
13
class TestConv(parameterized.TestCase):
    """Output-shape checks for the 1d/2d/3d convolution layers."""

    def test_Conv2D_img(self):
        # Two 4-channel images; channel k of each gets a filled square at an
        # offset that grows with k. Slices running past the border are clipped
        # like ordinary Python slices.
        img = jnp.zeros((2, 200, 198, 4))
        for k in range(4):
            x = 30 + 60 * k
            y = 20 + 60 * k
            img = img.at[0, x:x + 10, y:y + 10, k].set(1.0)
            img = img.at[1, x:x + 20, y:y + 20, k].set(3.0)

        net = bst.nn.Conv2d((200, 198, 4), out_channels=32, kernel_size=(3, 3),
                            stride=(2, 1), padding='VALID', groups=4)
        out = net(img)
        # VALID padding: floor((200 - 3) / 2) + 1 = 99 rows, (198 - 3) / 1 + 1 = 196 cols.
        self.assertEqual(out.shape, (2, 99, 196, 32))

    def test_conv1D(self):
        model = bst.nn.Conv1d((5, 3), out_channels=32, kernel_size=(3,))
        inputs = jnp.ones((2, 5, 3))
        out = model(inputs)
        # No padding argument is given; the expected shape pins the default
        # padding as size-preserving.
        self.assertEqual(out.shape, (2, 5, 32))

    def test_conv2D(self):
        model = bst.nn.Conv2d((5, 5, 3), out_channels=32, kernel_size=(3, 3))
        inputs = jnp.ones((2, 5, 5, 3))
        out = model(inputs)
        self.assertEqual(out.shape, (2, 5, 5, 32))

    def test_conv3D(self):
        model = bst.nn.Conv3d((5, 5, 5, 3), out_channels=32, kernel_size=(3, 3, 3))
        inputs = jnp.ones((2, 5, 5, 5, 3))
        out = model(inputs)
        self.assertEqual(out.shape, (2, 5, 5, 5, 32))
57
+
58
+
59
# unittest.skip is honoured both by pytest and by absltest.main(); a
# pytest.mark.skip marker is ignored when this file is run directly.
@unittest.skip("not implemented yet")
class TestConvTranspose1d(parameterized.TestCase):
    """Numerical checks for ``bst.nn.ConvTranspose1d`` (skipped for now)."""

    def test_conv_transpose(self):
        # All-ones input of length 8 with 3 channels, all-ones width-3 kernel.
        x = jnp.ones((1, 8, 3))
        for use_bias in [True, False]:
            conv_transpose_module = bst.nn.ConvTranspose1d(
                in_channels=3,
                out_channels=4,
                kernel_size=(3,),
                padding='VALID',
                w_initializer=bst.init.Constant(1.),
                b_initializer=bst.init.Constant(1.) if use_bias else None,
            )
            self.assertEqual(conv_transpose_module.w.shape, (3, 3, 4))
            y = conv_transpose_module(x)
            # VALID transposed conv yields length 8 + 3 - 1 = 10. Each output
            # sums 3 input channels over the 1, 2 or 3 overlapping kernel
            # taps, plus a bias of 1.
            correct_ans = jnp.array([[[4., 4., 4., 4.],
                                      [7., 7., 7., 7.],
                                      [10., 10., 10., 10.],
                                      [10., 10., 10., 10.],
                                      [10., 10., 10., 10.],
                                      [10., 10., 10., 10.],
                                      [10., 10., 10., 10.],
                                      [10., 10., 10., 10.],
                                      [7., 7., 7., 7.],
                                      [4., 4., 4., 4.]]])
            if not use_bias:
                correct_ans -= 1.
            self.assertTrue(jnp.allclose(y, correct_ans))

    def test_single_input_masked_conv_transpose(self):
        x = jnp.ones((1, 8, 3))
        # Lower-triangular mask over the last two weight axes, so each output
        # channel receives a different number of input channels.
        m = jnp.tril(jnp.ones((3, 3, 4)))
        conv_transpose_module = bst.nn.ConvTranspose1d(
            in_channels=3,
            out_channels=4,
            kernel_size=(3,),
            padding='VALID',
            mask=m,
            w_initializer=bst.init.Constant(),
            b_initializer=bst.init.Constant(),
        )
        self.assertEqual(conv_transpose_module.w.shape, (3, 3, 4))
        y = conv_transpose_module(x)
        correct_ans = jnp.array([[[4., 3., 2., 1.],
                                  [7., 5., 3., 1.],
                                  [10., 7., 4., 1.],
                                  [10., 7., 4., 1.],
                                  [10., 7., 4., 1.],
                                  [10., 7., 4., 1.],
                                  [10., 7., 4., 1.],
                                  [10., 7., 4., 1.],
                                  [7., 5., 3., 1.],
                                  [4., 3., 2., 1.]]])
        self.assertTrue(jnp.allclose(y, correct_ans))

    def test_computation_padding_same(self):
        data = jnp.ones([1, 3, 1])
        for use_bias in [True, False]:
            net = bst.nn.ConvTranspose1d(
                in_channels=1,
                out_channels=1,
                kernel_size=3,
                stride=1,
                padding="SAME",
                w_initializer=bst.init.Constant(),
                b_initializer=bst.init.Constant() if use_bias else None,
            )
            out = net(data)
            self.assertEqual(out.shape, (1, 3, 1))
            out = jnp.squeeze(out, axis=(0, 2))
            # Assuming Constant() defaults to 1: the centre output overlaps
            # 3 kernel taps and the edges 2.
            expected_out = jnp.asarray([2, 3, 2])
            if use_bias:
                expected_out += 1
            self.assertTrue(jnp.allclose(out, expected_out, rtol=1e-5))
138
+
139
+
140
# unittest.skip is honoured both by pytest and by absltest.main(); a
# pytest.mark.skip marker is ignored when this file is run directly.
@unittest.skip("not implemented yet")
class TestConvTranspose2d(parameterized.TestCase):
    """Shape checks for ``bst.nn.ConvTranspose2d`` (skipped for now)."""

    def test_conv_transpose(self):
        x = jnp.ones((1, 8, 8, 3))
        for use_bias in [True, False]:
            conv_transpose_module = bst.nn.ConvTranspose2d(
                in_channels=3,
                out_channels=4,
                kernel_size=(3, 3),
                padding='VALID',
                w_initializer=bst.init.Constant(),
                b_initializer=bst.init.Constant() if use_bias else None,
            )
            self.assertEqual(conv_transpose_module.w.shape, (3, 3, 3, 4))
            y = conv_transpose_module(x)
            # VALID transposed conv with stride 1: each spatial dim grows to 8 + 3 - 1.
            self.assertEqual(y.shape, (1, 10, 10, 4))

    def test_single_input_masked_conv_transpose(self):
        x = jnp.ones((1, 8, 8, 3))
        m = jnp.tril(jnp.ones((3, 3, 3, 4)))
        conv_transpose_module = bst.nn.ConvTranspose2d(
            in_channels=3,
            out_channels=4,
            kernel_size=(3, 3),
            padding='VALID',
            mask=m,
            w_initializer=bst.init.Constant(),
        )
        y = conv_transpose_module(x)
        self.assertEqual(y.shape, (1, 10, 10, 4))

    def test_computation_padding_same(self):
        x = jnp.ones((1, 8, 8, 3))
        for use_bias in [True, False]:
            conv_transpose_module = bst.nn.ConvTranspose2d(
                in_channels=3,
                out_channels=4,
                kernel_size=(3, 3),
                stride=1,
                padding='SAME',
                w_initializer=bst.init.Constant(),
                b_initializer=bst.init.Constant() if use_bias else None,
            )
            y = conv_transpose_module(x)
            # SAME padding with stride 1 preserves the spatial size.
            self.assertEqual(y.shape, (1, 8, 8, 4))
188
+
189
+
190
# unittest.skip is honoured both by pytest and by absltest.main(); a
# pytest.mark.skip marker is ignored when this file is run directly.
@unittest.skip("not implemented yet")
class TestConvTranspose3d(parameterized.TestCase):
    """Shape checks for ``bst.nn.ConvTranspose3d`` (skipped for now)."""

    def test_conv_transpose(self):
        x = jnp.ones((1, 8, 8, 8, 3))
        for use_bias in [True, False]:
            conv_transpose_module = bst.nn.ConvTranspose3d(
                in_channels=3,
                out_channels=4,
                kernel_size=(3, 3, 3),
                padding='VALID',
                w_initializer=bst.init.Constant(),
                b_initializer=bst.init.Constant() if use_bias else None,
            )
            # Same weight layout as the 1d/2d tests, with one more kernel axis.
            self.assertEqual(conv_transpose_module.w.shape, (3, 3, 3, 3, 4))
            y = conv_transpose_module(x)
            # VALID transposed conv with stride 1: each spatial dim grows to 8 + 3 - 1.
            self.assertEqual(y.shape, (1, 10, 10, 10, 4))

    def test_single_input_masked_conv_transpose(self):
        x = jnp.ones((1, 8, 8, 8, 3))
        m = jnp.tril(jnp.ones((3, 3, 3, 3, 4)))
        conv_transpose_module = bst.nn.ConvTranspose3d(
            in_channels=3,
            out_channels=4,
            kernel_size=(3, 3, 3),
            padding='VALID',
            mask=m,
            w_initializer=bst.init.Constant(),
        )
        y = conv_transpose_module(x)
        self.assertEqual(y.shape, (1, 10, 10, 10, 4))

    def test_computation_padding_same(self):
        x = jnp.ones((1, 8, 8, 8, 3))
        for use_bias in [True, False]:
            conv_transpose_module = bst.nn.ConvTranspose3d(
                in_channels=3,
                out_channels=4,
                kernel_size=(3, 3, 3),
                stride=1,
                padding='SAME',
                w_initializer=bst.init.Constant(),
                b_initializer=bst.init.Constant() if use_bias else None,
            )
            y = conv_transpose_module(x)
            # SAME padding with stride 1 preserves the spatial size.
            self.assertEqual(y.shape, (1, 8, 8, 8, 4))
237
+
238
if __name__ == '__main__':
    # Executed only when this file is run as a script: hand control to
    # absltest's test runner.
    absltest.main()
@@ -0,0 +1,59 @@
1
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ from __future__ import annotations
16
+
17
+ from typing import Optional, Callable, Union
18
+
19
+ from brainstate import init
20
+ from brainstate._state import ParamState
21
+ from brainstate.nn._module import Module
22
+ from brainstate.typing import ArrayLike
23
+
24
# Public names exported by this module.
__all__ = ['Embedding']
27
+
28
+
29
class Embedding(Module):
    r"""
    Lookup table that maps indices to stored vectors of a fixed size.

    The table is created once in ``__init__`` via ``init.param`` and held in a
    ``ParamState`` named ``weight``.

    Args:
        num_embeddings: Number of rows in the table (dictionary size). Must be non-negative.
        embedding_size: Length of each embedding vector. Must be non-negative.
        embedding_init: Initializer (callable) or concrete array for the lookup
            table of shape ``(num_embeddings, embedding_size)``; it is passed to
            ``init.param`` together with that shape.
        name: Optional module name, forwarded to ``Module.__init__``.

    Raises:
        ValueError: If ``num_embeddings`` or ``embedding_size`` is negative
            (``num_embeddings`` is checked first).
    """

    def __init__(
        self,
        num_embeddings: int,
        embedding_size: int,
        embedding_init: Union[Callable, ArrayLike] = init.LecunUniform(),
        name: Optional[str] = None,
    ):
        super().__init__(name=name)

        # Validate sizes in declaration order so the first offending argument
        # determines the error message.
        size_checks = (
            (num_embeddings, "num_embeddings must not be negative."),
            (embedding_size, "embedding_size must not be negative."),
        )
        for size, complaint in size_checks:
            if size < 0:
                raise ValueError(complaint)

        self.num_embeddings = num_embeddings
        self.embedding_size = embedding_size
        self.out_size = (embedding_size,)

        table_shape = (self.num_embeddings, self.embedding_size)
        self.weight = ParamState(init.param(embedding_init, table_shape))

    def update(self, indices: ArrayLike):
        """Return the rows of the lookup table selected by ``indices``.

        Plain ``[]`` indexing is applied to the table value. For an integer
        index array this presumably yields shape
        ``indices.shape + (embedding_size,)`` (NumPy advanced-indexing rules);
        confirm against the array backend in use.
        """
        table = self.weight.value
        return table[indices]