brainstate 0.0.2.post20241009__py2.py3-none-any.whl → 0.1.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (175):
  1. brainstate/__init__.py +31 -11
  2. brainstate/_state.py +760 -316
  3. brainstate/_state_test.py +41 -12
  4. brainstate/_utils.py +31 -4
  5. brainstate/augment/__init__.py +40 -0
  6. brainstate/augment/_autograd.py +608 -0
  7. brainstate/augment/_autograd_test.py +1193 -0
  8. brainstate/augment/_eval_shape.py +102 -0
  9. brainstate/augment/_eval_shape_test.py +40 -0
  10. brainstate/augment/_mapping.py +525 -0
  11. brainstate/augment/_mapping_test.py +210 -0
  12. brainstate/augment/_random.py +99 -0
  13. brainstate/{transform → compile}/__init__.py +25 -13
  14. brainstate/compile/_ad_checkpoint.py +204 -0
  15. brainstate/compile/_ad_checkpoint_test.py +51 -0
  16. brainstate/compile/_conditions.py +259 -0
  17. brainstate/compile/_conditions_test.py +221 -0
  18. brainstate/compile/_error_if.py +94 -0
  19. brainstate/compile/_error_if_test.py +54 -0
  20. brainstate/compile/_jit.py +314 -0
  21. brainstate/compile/_jit_test.py +143 -0
  22. brainstate/compile/_loop_collect_return.py +516 -0
  23. brainstate/compile/_loop_collect_return_test.py +59 -0
  24. brainstate/compile/_loop_no_collection.py +185 -0
  25. brainstate/compile/_loop_no_collection_test.py +51 -0
  26. brainstate/compile/_make_jaxpr.py +756 -0
  27. brainstate/compile/_make_jaxpr_test.py +134 -0
  28. brainstate/compile/_progress_bar.py +111 -0
  29. brainstate/compile/_unvmap.py +159 -0
  30. brainstate/compile/_util.py +147 -0
  31. brainstate/environ.py +408 -381
  32. brainstate/environ_test.py +34 -32
  33. brainstate/{nn/event → event}/__init__.py +6 -6
  34. brainstate/event/_csr.py +308 -0
  35. brainstate/event/_csr_test.py +118 -0
  36. brainstate/event/_fixed_probability.py +271 -0
  37. brainstate/event/_fixed_probability_test.py +128 -0
  38. brainstate/event/_linear.py +219 -0
  39. brainstate/event/_linear_test.py +112 -0
  40. brainstate/{nn/event → event}/_misc.py +7 -7
  41. brainstate/functional/_activations.py +521 -511
  42. brainstate/functional/_activations_test.py +300 -300
  43. brainstate/functional/_normalization.py +43 -43
  44. brainstate/functional/_others.py +15 -15
  45. brainstate/functional/_spikes.py +49 -49
  46. brainstate/graph/__init__.py +33 -0
  47. brainstate/graph/_graph_context.py +443 -0
  48. brainstate/graph/_graph_context_test.py +65 -0
  49. brainstate/graph/_graph_convert.py +246 -0
  50. brainstate/graph/_graph_node.py +300 -0
  51. brainstate/graph/_graph_node_test.py +75 -0
  52. brainstate/graph/_graph_operation.py +1746 -0
  53. brainstate/graph/_graph_operation_test.py +724 -0
  54. brainstate/init/_base.py +28 -10
  55. brainstate/init/_generic.py +175 -172
  56. brainstate/init/_random_inits.py +470 -415
  57. brainstate/init/_random_inits_test.py +150 -0
  58. brainstate/init/_regular_inits.py +66 -69
  59. brainstate/init/_regular_inits_test.py +51 -0
  60. brainstate/mixin.py +236 -244
  61. brainstate/mixin_test.py +44 -46
  62. brainstate/nn/__init__.py +26 -51
  63. brainstate/nn/_collective_ops.py +199 -0
  64. brainstate/nn/_dyn_impl/__init__.py +46 -0
  65. brainstate/nn/_dyn_impl/_dynamics_neuron.py +290 -0
  66. brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +162 -0
  67. brainstate/nn/_dyn_impl/_dynamics_synapse.py +320 -0
  68. brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +132 -0
  69. brainstate/nn/_dyn_impl/_inputs.py +154 -0
  70. brainstate/nn/{_projection/__init__.py → _dyn_impl/_projection_alignpost.py} +6 -13
  71. brainstate/nn/_dyn_impl/_rate_rnns.py +400 -0
  72. brainstate/nn/_dyn_impl/_rate_rnns_test.py +64 -0
  73. brainstate/nn/_dyn_impl/_readout.py +128 -0
  74. brainstate/nn/_dyn_impl/_readout_test.py +54 -0
  75. brainstate/nn/_dynamics/__init__.py +37 -0
  76. brainstate/nn/_dynamics/_dynamics_base.py +631 -0
  77. brainstate/nn/_dynamics/_dynamics_base_test.py +79 -0
  78. brainstate/nn/_dynamics/_projection_base.py +346 -0
  79. brainstate/nn/_dynamics/_state_delay.py +453 -0
  80. brainstate/nn/_dynamics/_synouts.py +161 -0
  81. brainstate/nn/_dynamics/_synouts_test.py +58 -0
  82. brainstate/nn/_elementwise/__init__.py +22 -0
  83. brainstate/nn/_elementwise/_dropout.py +418 -0
  84. brainstate/nn/_elementwise/_dropout_test.py +100 -0
  85. brainstate/nn/_elementwise/_elementwise.py +1122 -0
  86. brainstate/nn/_elementwise/_elementwise_test.py +171 -0
  87. brainstate/nn/_exp_euler.py +97 -0
  88. brainstate/nn/_exp_euler_test.py +36 -0
  89. brainstate/nn/_interaction/__init__.py +32 -0
  90. brainstate/nn/_interaction/_connections.py +726 -0
  91. brainstate/nn/_interaction/_connections_test.py +254 -0
  92. brainstate/nn/_interaction/_embedding.py +59 -0
  93. brainstate/nn/_interaction/_normalizations.py +388 -0
  94. brainstate/nn/_interaction/_normalizations_test.py +75 -0
  95. brainstate/nn/_interaction/_poolings.py +1179 -0
  96. brainstate/nn/_interaction/_poolings_test.py +219 -0
  97. brainstate/nn/_module.py +328 -0
  98. brainstate/nn/_module_test.py +211 -0
  99. brainstate/nn/metrics.py +309 -309
  100. brainstate/optim/__init__.py +14 -2
  101. brainstate/optim/_base.py +66 -0
  102. brainstate/optim/_lr_scheduler.py +363 -400
  103. brainstate/optim/_lr_scheduler_test.py +25 -24
  104. brainstate/optim/_optax_optimizer.py +103 -176
  105. brainstate/optim/_optax_optimizer_test.py +41 -1
  106. brainstate/optim/_sgd_optimizer.py +950 -1025
  107. brainstate/random/_rand_funs.py +3269 -3268
  108. brainstate/random/_rand_funs_test.py +568 -0
  109. brainstate/random/_rand_seed.py +149 -117
  110. brainstate/random/_rand_seed_test.py +50 -0
  111. brainstate/random/_rand_state.py +1360 -1318
  112. brainstate/random/_random_for_unit.py +13 -13
  113. brainstate/surrogate.py +1262 -1243
  114. brainstate/{nn/_projection/_utils.py → transform.py} +1 -2
  115. brainstate/typing.py +157 -130
  116. brainstate/util/__init__.py +52 -0
  117. brainstate/util/_caller.py +100 -0
  118. brainstate/util/_dict.py +734 -0
  119. brainstate/util/_dict_test.py +160 -0
  120. brainstate/util/_error.py +28 -0
  121. brainstate/util/_filter.py +178 -0
  122. brainstate/util/_others.py +497 -0
  123. brainstate/util/_pretty_repr.py +208 -0
  124. brainstate/util/_scaling.py +260 -0
  125. brainstate/util/_struct.py +524 -0
  126. brainstate/util/_tracers.py +75 -0
  127. brainstate/{_visualization.py → util/_visualization.py} +16 -16
  128. {brainstate-0.0.2.post20241009.dist-info → brainstate-0.1.0.dist-info}/METADATA +11 -11
  129. brainstate-0.1.0.dist-info/RECORD +135 -0
  130. brainstate/_module.py +0 -1637
  131. brainstate/_module_test.py +0 -207
  132. brainstate/nn/_base.py +0 -251
  133. brainstate/nn/_connections.py +0 -686
  134. brainstate/nn/_dynamics.py +0 -426
  135. brainstate/nn/_elementwise.py +0 -1438
  136. brainstate/nn/_embedding.py +0 -66
  137. brainstate/nn/_misc.py +0 -133
  138. brainstate/nn/_normalizations.py +0 -389
  139. brainstate/nn/_others.py +0 -101
  140. brainstate/nn/_poolings.py +0 -1229
  141. brainstate/nn/_poolings_test.py +0 -231
  142. brainstate/nn/_projection/_align_post.py +0 -546
  143. brainstate/nn/_projection/_align_pre.py +0 -599
  144. brainstate/nn/_projection/_delta.py +0 -241
  145. brainstate/nn/_projection/_vanilla.py +0 -101
  146. brainstate/nn/_rate_rnns.py +0 -410
  147. brainstate/nn/_readout.py +0 -136
  148. brainstate/nn/_synouts.py +0 -166
  149. brainstate/nn/event/csr.py +0 -312
  150. brainstate/nn/event/csr_test.py +0 -118
  151. brainstate/nn/event/fixed_probability.py +0 -276
  152. brainstate/nn/event/fixed_probability_test.py +0 -127
  153. brainstate/nn/event/linear.py +0 -220
  154. brainstate/nn/event/linear_test.py +0 -111
  155. brainstate/random/random_test.py +0 -593
  156. brainstate/transform/_autograd.py +0 -585
  157. brainstate/transform/_autograd_test.py +0 -1181
  158. brainstate/transform/_conditions.py +0 -334
  159. brainstate/transform/_conditions_test.py +0 -220
  160. brainstate/transform/_error_if.py +0 -94
  161. brainstate/transform/_error_if_test.py +0 -55
  162. brainstate/transform/_jit.py +0 -265
  163. brainstate/transform/_jit_test.py +0 -118
  164. brainstate/transform/_loop_collect_return.py +0 -502
  165. brainstate/transform/_loop_no_collection.py +0 -170
  166. brainstate/transform/_make_jaxpr.py +0 -739
  167. brainstate/transform/_make_jaxpr_test.py +0 -131
  168. brainstate/transform/_mapping.py +0 -109
  169. brainstate/transform/_progress_bar.py +0 -111
  170. brainstate/transform/_unvmap.py +0 -143
  171. brainstate/util.py +0 -746
  172. brainstate-0.0.2.post20241009.dist-info/RECORD +0 -87
  173. {brainstate-0.0.2.post20241009.dist-info → brainstate-0.1.0.dist-info}/LICENSE +0 -0
  174. {brainstate-0.0.2.post20241009.dist-info → brainstate-0.1.0.dist-info}/WHEEL +0 -0
  175. {brainstate-0.0.2.post20241009.dist-info → brainstate-0.1.0.dist-info}/top_level.txt +0 -0
@@ -1,231 +0,0 @@
1
- # -*- coding: utf-8 -*-
2
-
3
-
4
- import jax
5
- import numpy as np
6
- from absl.testing import absltest
7
- from absl.testing import parameterized
8
-
9
- import brainstate as bst
10
- import brainstate.nn as nn
11
-
12
-
13
class TestFlatten(parameterized.TestCase):
    """Output-shape tests for ``nn.Flatten``."""

    # Input shapes shared by the ``start_axis`` tests without ``in_size``.
    _SIZES = (
        (16, 32, 32, 8),
        (32, 8),
        (10, 20, 30),
    )

    def test_flatten1(self):
        # start_axis=0 collapses every axis of the input into one.
        for size in self._SIZES:
            with self.subTest(size=size):
                arr = bst.random.rand(*size)
                out = nn.Flatten(start_axis=0)(arr)
                self.assertEqual(out.shape, (np.prod(size),))

    def test_flatten2(self):
        # start_axis=1 keeps the first axis and collapses the rest.
        for size in self._SIZES:
            with self.subTest(size=size):
                arr = bst.random.rand(*size)
                out = nn.Flatten(start_axis=1)(arr)
                self.assertEqual(out.shape, (size[0], np.prod(size[1:])))

    def test_flatten3(self):
        # With ``in_size`` given, the expected shape shows that only the
        # trailing ``in_size`` axes are flattened (from start_axis on) and the
        # leading axes are preserved.
        arr = bst.random.rand(16, 32, 32, 8)
        out = nn.Flatten(start_axis=0, in_size=(32, 8))(arr)
        self.assertEqual(out.shape, (16, 32, 32 * 8))

    def test_flatten4(self):
        # start_axis=1 relative to in_size=(32, 32, 8): the first in_size axis
        # is kept and the remaining two are merged.
        arr = bst.random.rand(16, 32, 32, 8)
        out = nn.Flatten(start_axis=1, in_size=(32, 32, 8))(arr)
        self.assertEqual(out.shape, (16, 32, 32 * 8))
49
-
50
-
51
class TestUnflatten(parameterized.TestCase):
    """Placeholder for ``Unflatten`` tests; no test cases are defined yet."""
53
-
54
-
55
class TestPool(parameterized.TestCase):
    """Output-shape tests for the pooling layers in ``brainstate.nn``.

    These tests only check shapes (and, for the adaptive max-pool sweeps,
    that the rank of the input is preserved); pooled values are not verified.
    """

    def tearDown(self):
        # Release device buffers after every test, including failing ones, so
        # the large random inputs used below do not accumulate between cases.
        bst.util.clear_buffer_memory()
        super().tearDown()

    def _check_pool2d_shapes(self, pool_cls):
        """Check a 2x2, stride-2 ``pool_cls`` layer on a (16, 32, 32, 8) input.

        Args:
          pool_cls: a 2-D pooling layer class constructed as
            ``pool_cls(kernel_size, stride, channel_axis=..., padding=...)``.
        """
        arr = bst.random.rand(16, 32, 32, 8)
        cases = (
            (dict(channel_axis=-1), (16, 16, 16, 8)),
            # channel_axis=None: the last two axes are pooled.
            (dict(channel_axis=None), (16, 32, 16, 4)),
            (dict(channel_axis=None, padding=1), (16, 32, 17, 5)),
            (dict(channel_axis=None, padding=(2, 1)), (16, 32, 18, 5)),
            (dict(channel_axis=-1, padding=(1, 1)), (16, 17, 17, 8)),
            (dict(channel_axis=2, padding=(1, 1)), (16, 17, 32, 5)),
        )
        for kwargs, expected in cases:
            with self.subTest(**kwargs):
                out = pool_cls(2, 2, **kwargs)(arr)
                self.assertEqual(out.shape, expected)

    def test_MaxPool2d_v1(self):
        self._check_pool2d_shapes(nn.MaxPool2d)

    def test_AvgPool2d_v1(self):
        self._check_pool2d_shapes(nn.AvgPool2d)

    @parameterized.named_parameters(
        dict(testcase_name=f'target_size={target_size}', target_size=target_size)
        for target_size in [10, 9, 8, 7, 6]
    )
    def test_adaptive_pool1d(self, target_size):
        from brainstate.nn._poolings import _adaptive_pool1d

        arr = bst.random.rand(100)
        # Exercise both an averaging and a max reduction over the segments.
        for name, op in (('mean', jax.numpy.mean), ('max', jax.numpy.max)):
            with self.subTest(op=name):
                out = _adaptive_pool1d(arr, target_size, op)
                self.assertEqual(out.shape, (target_size,))

    def test_AdaptiveAvgPool2d_v1(self):
        x = bst.random.randn(64, 8, 9)

        output = nn.AdaptiveAvgPool2d((5, 7), channel_axis=0)(x)
        self.assertEqual(output.shape, (64, 5, 7))

        output = nn.AdaptiveAvgPool2d((2, 3), channel_axis=0)(x)
        self.assertEqual(output.shape, (64, 2, 3))

        output = nn.AdaptiveAvgPool2d((2, 3), channel_axis=-1)(x)
        self.assertEqual(output.shape, (2, 3, 9))

        output = nn.AdaptiveAvgPool2d((2, 3), channel_axis=1)(x)
        self.assertEqual(output.shape, (2, 8, 3))

        output = nn.AdaptiveAvgPool2d((2, 3), channel_axis=None)(x)
        self.assertEqual(output.shape, (64, 2, 3))

    def test_AdaptiveAvgPool2d_v2(self):
        bst.random.seed()
        x = bst.random.randn(128, 64, 32, 16)

        output = nn.AdaptiveAvgPool2d((5, 7), channel_axis=0)(x)
        self.assertEqual(output.shape, (128, 64, 5, 7))

        output = nn.AdaptiveAvgPool2d((2, 3), channel_axis=0)(x)
        self.assertEqual(output.shape, (128, 64, 2, 3))

        output = nn.AdaptiveAvgPool2d((2, 3), channel_axis=-1)(x)
        self.assertEqual(output.shape, (128, 2, 3, 16))

        output = nn.AdaptiveAvgPool2d((2, 3), channel_axis=1)(x)
        self.assertEqual(output.shape, (128, 64, 2, 3))

    def test_AdaptiveAvgPool3d_v1(self):
        x = bst.random.randn(10, 128, 64, 32)
        net = nn.AdaptiveAvgPool3d(target_size=[6, 5, 3], channel_axis=0)
        output = net(x)
        self.assertEqual(output.shape, (10, 6, 5, 3))

    def test_AdaptiveAvgPool3d_v2(self):
        x = bst.random.randn(10, 20, 128, 64, 32)
        net = nn.AdaptiveAvgPool3d(target_size=[6, 5, 3])
        output = net(x)
        self.assertEqual(output.shape, (10, 6, 5, 3, 32))

    @parameterized.product(
        axis=(-1, 0, 1)
    )
    def test_AdaptiveMaxPool1d_v1(self, axis):
        x = bst.random.randn(32, 16)
        net = nn.AdaptiveMaxPool1d(target_size=4, channel_axis=axis)
        output = net(x)
        # Adaptive pooling resizes axes but must not change the rank.
        self.assertEqual(output.ndim, x.ndim)

    @parameterized.product(
        axis=(-1, 0, 1, 2)
    )
    def test_AdaptiveMaxPool1d_v2(self, axis):
        x = bst.random.randn(2, 32, 16)
        net = nn.AdaptiveMaxPool1d(target_size=4, channel_axis=axis)
        output = net(x)
        self.assertEqual(output.ndim, x.ndim)

    @parameterized.product(
        axis=(-1, 0, 1, 2)
    )
    def test_AdaptiveMaxPool2d_v1(self, axis):
        x = bst.random.randn(32, 16, 12)
        net = nn.AdaptiveMaxPool2d(target_size=[5, 4], channel_axis=axis)
        output = net(x)
        self.assertEqual(output.ndim, x.ndim)

    @parameterized.product(
        axis=(-1, 0, 1, 2, 3)
    )
    def test_AdaptiveMaxPool2d_v2(self, axis):
        x = bst.random.randn(2, 32, 16, 12)
        net = nn.AdaptiveMaxPool2d(target_size=[5, 4], channel_axis=axis)
        output = net(x)
        self.assertEqual(output.ndim, x.ndim)

    @parameterized.product(
        axis=(-1, 0, 1, 2, 3)
    )
    def test_AdaptiveMaxPool3d_v1(self, axis):
        x = bst.random.randn(2, 128, 64, 32)
        net = nn.AdaptiveMaxPool3d(target_size=[6, 5, 4], channel_axis=axis)
        output = net(x)
        self.assertEqual(output.ndim, x.ndim)

    # Distinct name so this 5-D case does not shadow the 4-D case above.
    @parameterized.product(
        axis=(-1, 0, 1, 2, 3, 4)
    )
    def test_AdaptiveMaxPool3d_v2(self, axis):
        x = bst.random.randn(2, 128, 64, 32, 16)
        net = nn.AdaptiveMaxPool3d(target_size=[6, 5, 4], channel_axis=axis)
        output = net(x)
        self.assertEqual(output.ndim, x.ndim)
228
-
229
-
230
if __name__ == "__main__":
    # Run every test case in this module through absl's test runner.
    absltest.main()