brainstate 0.0.2.post20241010__py2.py3-none-any.whl → 0.1.0.post20241122__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (184)
  1. benchmark/COBA_2005.py +125 -0
  2. benchmark/CUBA_2005.py +149 -0
  3. brainstate/__init__.py +31 -11
  4. brainstate/_state.py +760 -316
  5. brainstate/_state_test.py +41 -12
  6. brainstate/_utils.py +31 -4
  7. brainstate/augment/__init__.py +40 -0
  8. brainstate/augment/_autograd.py +611 -0
  9. brainstate/augment/_autograd_test.py +1193 -0
  10. brainstate/augment/_eval_shape.py +102 -0
  11. brainstate/augment/_eval_shape_test.py +40 -0
  12. brainstate/augment/_mapping.py +525 -0
  13. brainstate/augment/_mapping_test.py +210 -0
  14. brainstate/augment/_random.py +99 -0
  15. brainstate/{transform → compile}/__init__.py +25 -13
  16. brainstate/compile/_ad_checkpoint.py +204 -0
  17. brainstate/compile/_ad_checkpoint_test.py +51 -0
  18. brainstate/compile/_conditions.py +259 -0
  19. brainstate/compile/_conditions_test.py +221 -0
  20. brainstate/compile/_error_if.py +94 -0
  21. brainstate/compile/_error_if_test.py +54 -0
  22. brainstate/compile/_jit.py +314 -0
  23. brainstate/compile/_jit_test.py +143 -0
  24. brainstate/compile/_loop_collect_return.py +516 -0
  25. brainstate/compile/_loop_collect_return_test.py +59 -0
  26. brainstate/compile/_loop_no_collection.py +185 -0
  27. brainstate/compile/_loop_no_collection_test.py +51 -0
  28. brainstate/compile/_make_jaxpr.py +756 -0
  29. brainstate/compile/_make_jaxpr_test.py +134 -0
  30. brainstate/compile/_progress_bar.py +111 -0
  31. brainstate/compile/_unvmap.py +159 -0
  32. brainstate/compile/_util.py +147 -0
  33. brainstate/environ.py +408 -381
  34. brainstate/environ_test.py +34 -32
  35. brainstate/event/__init__.py +27 -0
  36. brainstate/event/_csr.py +316 -0
  37. brainstate/event/_csr_benchmark.py +14 -0
  38. brainstate/event/_csr_test.py +118 -0
  39. brainstate/event/_fixed_probability.py +708 -0
  40. brainstate/event/_fixed_probability_benchmark.py +128 -0
  41. brainstate/event/_fixed_probability_test.py +131 -0
  42. brainstate/event/_linear.py +359 -0
  43. brainstate/event/_linear_benckmark.py +82 -0
  44. brainstate/event/_linear_test.py +117 -0
  45. brainstate/{nn/event → event}/_misc.py +7 -7
  46. brainstate/event/_xla_custom_op.py +312 -0
  47. brainstate/event/_xla_custom_op_test.py +55 -0
  48. brainstate/functional/_activations.py +521 -511
  49. brainstate/functional/_activations_test.py +300 -300
  50. brainstate/functional/_normalization.py +43 -43
  51. brainstate/functional/_others.py +15 -15
  52. brainstate/functional/_spikes.py +49 -49
  53. brainstate/graph/__init__.py +33 -0
  54. brainstate/graph/_graph_context.py +443 -0
  55. brainstate/graph/_graph_context_test.py +65 -0
  56. brainstate/graph/_graph_convert.py +246 -0
  57. brainstate/graph/_graph_node.py +300 -0
  58. brainstate/graph/_graph_node_test.py +75 -0
  59. brainstate/graph/_graph_operation.py +1746 -0
  60. brainstate/graph/_graph_operation_test.py +724 -0
  61. brainstate/init/_base.py +28 -10
  62. brainstate/init/_generic.py +175 -172
  63. brainstate/init/_random_inits.py +470 -415
  64. brainstate/init/_random_inits_test.py +150 -0
  65. brainstate/init/_regular_inits.py +66 -69
  66. brainstate/init/_regular_inits_test.py +51 -0
  67. brainstate/mixin.py +236 -244
  68. brainstate/mixin_test.py +44 -46
  69. brainstate/nn/__init__.py +26 -51
  70. brainstate/nn/_collective_ops.py +199 -0
  71. brainstate/nn/_dyn_impl/__init__.py +46 -0
  72. brainstate/nn/_dyn_impl/_dynamics_neuron.py +290 -0
  73. brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +162 -0
  74. brainstate/nn/_dyn_impl/_dynamics_synapse.py +315 -0
  75. brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +132 -0
  76. brainstate/nn/_dyn_impl/_inputs.py +154 -0
  77. brainstate/nn/{event/__init__.py → _dyn_impl/_projection_alignpost.py} +8 -8
  78. brainstate/nn/_dyn_impl/_rate_rnns.py +400 -0
  79. brainstate/nn/_dyn_impl/_rate_rnns_test.py +64 -0
  80. brainstate/nn/_dyn_impl/_readout.py +128 -0
  81. brainstate/nn/_dyn_impl/_readout_test.py +54 -0
  82. brainstate/nn/_dynamics/__init__.py +37 -0
  83. brainstate/nn/_dynamics/_dynamics_base.py +631 -0
  84. brainstate/nn/_dynamics/_dynamics_base_test.py +79 -0
  85. brainstate/nn/_dynamics/_projection_base.py +346 -0
  86. brainstate/nn/_dynamics/_state_delay.py +453 -0
  87. brainstate/nn/_dynamics/_synouts.py +161 -0
  88. brainstate/nn/_dynamics/_synouts_test.py +58 -0
  89. brainstate/nn/_elementwise/__init__.py +22 -0
  90. brainstate/nn/_elementwise/_dropout.py +418 -0
  91. brainstate/nn/_elementwise/_dropout_test.py +100 -0
  92. brainstate/nn/_elementwise/_elementwise.py +1122 -0
  93. brainstate/nn/_elementwise/_elementwise_test.py +171 -0
  94. brainstate/nn/_exp_euler.py +97 -0
  95. brainstate/nn/_exp_euler_test.py +36 -0
  96. brainstate/nn/_interaction/__init__.py +41 -0
  97. brainstate/nn/_interaction/_conv.py +499 -0
  98. brainstate/nn/_interaction/_conv_test.py +239 -0
  99. brainstate/nn/_interaction/_embedding.py +59 -0
  100. brainstate/nn/_interaction/_linear.py +582 -0
  101. brainstate/nn/_interaction/_linear_test.py +42 -0
  102. brainstate/nn/_interaction/_normalizations.py +388 -0
  103. brainstate/nn/_interaction/_normalizations_test.py +75 -0
  104. brainstate/nn/_interaction/_poolings.py +1179 -0
  105. brainstate/nn/_interaction/_poolings_test.py +219 -0
  106. brainstate/nn/_module.py +328 -0
  107. brainstate/nn/_module_test.py +211 -0
  108. brainstate/nn/metrics.py +309 -309
  109. brainstate/optim/__init__.py +14 -2
  110. brainstate/optim/_base.py +66 -0
  111. brainstate/optim/_lr_scheduler.py +363 -400
  112. brainstate/optim/_lr_scheduler_test.py +25 -24
  113. brainstate/optim/_optax_optimizer.py +121 -176
  114. brainstate/optim/_optax_optimizer_test.py +41 -1
  115. brainstate/optim/_sgd_optimizer.py +950 -1025
  116. brainstate/random/_rand_funs.py +3269 -3268
  117. brainstate/random/_rand_funs_test.py +568 -0
  118. brainstate/random/_rand_seed.py +149 -117
  119. brainstate/random/_rand_seed_test.py +50 -0
  120. brainstate/random/_rand_state.py +1356 -1321
  121. brainstate/random/_random_for_unit.py +13 -13
  122. brainstate/surrogate.py +1262 -1243
  123. brainstate/{nn/_projection/_utils.py → transform.py} +1 -2
  124. brainstate/typing.py +157 -130
  125. brainstate/util/__init__.py +52 -0
  126. brainstate/util/_caller.py +100 -0
  127. brainstate/util/_dict.py +734 -0
  128. brainstate/util/_dict_test.py +160 -0
  129. brainstate/{nn/_projection/__init__.py → util/_error.py} +9 -13
  130. brainstate/util/_filter.py +178 -0
  131. brainstate/util/_others.py +497 -0
  132. brainstate/util/_pretty_repr.py +208 -0
  133. brainstate/util/_scaling.py +260 -0
  134. brainstate/util/_struct.py +524 -0
  135. brainstate/util/_tracers.py +75 -0
  136. brainstate/{_visualization.py → util/_visualization.py} +16 -16
  137. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/METADATA +11 -11
  138. brainstate-0.1.0.post20241122.dist-info/RECORD +144 -0
  139. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/top_level.txt +1 -0
  140. brainstate/_module.py +0 -1637
  141. brainstate/_module_test.py +0 -207
  142. brainstate/nn/_base.py +0 -251
  143. brainstate/nn/_connections.py +0 -686
  144. brainstate/nn/_dynamics.py +0 -426
  145. brainstate/nn/_elementwise.py +0 -1438
  146. brainstate/nn/_embedding.py +0 -66
  147. brainstate/nn/_misc.py +0 -133
  148. brainstate/nn/_normalizations.py +0 -389
  149. brainstate/nn/_others.py +0 -101
  150. brainstate/nn/_poolings.py +0 -1229
  151. brainstate/nn/_poolings_test.py +0 -231
  152. brainstate/nn/_projection/_align_post.py +0 -546
  153. brainstate/nn/_projection/_align_pre.py +0 -599
  154. brainstate/nn/_projection/_delta.py +0 -241
  155. brainstate/nn/_projection/_vanilla.py +0 -101
  156. brainstate/nn/_rate_rnns.py +0 -410
  157. brainstate/nn/_readout.py +0 -136
  158. brainstate/nn/_synouts.py +0 -166
  159. brainstate/nn/event/csr.py +0 -312
  160. brainstate/nn/event/csr_test.py +0 -118
  161. brainstate/nn/event/fixed_probability.py +0 -276
  162. brainstate/nn/event/fixed_probability_test.py +0 -127
  163. brainstate/nn/event/linear.py +0 -220
  164. brainstate/nn/event/linear_test.py +0 -111
  165. brainstate/random/random_test.py +0 -593
  166. brainstate/transform/_autograd.py +0 -585
  167. brainstate/transform/_autograd_test.py +0 -1181
  168. brainstate/transform/_conditions.py +0 -334
  169. brainstate/transform/_conditions_test.py +0 -220
  170. brainstate/transform/_error_if.py +0 -94
  171. brainstate/transform/_error_if_test.py +0 -55
  172. brainstate/transform/_jit.py +0 -265
  173. brainstate/transform/_jit_test.py +0 -118
  174. brainstate/transform/_loop_collect_return.py +0 -502
  175. brainstate/transform/_loop_no_collection.py +0 -170
  176. brainstate/transform/_make_jaxpr.py +0 -739
  177. brainstate/transform/_make_jaxpr_test.py +0 -131
  178. brainstate/transform/_mapping.py +0 -109
  179. brainstate/transform/_progress_bar.py +0 -111
  180. brainstate/transform/_unvmap.py +0 -143
  181. brainstate/util.py +0 -746
  182. brainstate-0.0.2.post20241010.dist-info/RECORD +0 -87
  183. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/LICENSE +0 -0
  184. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/WHEEL +0 -0
benchmark/COBA_2005.py ADDED
@@ -0,0 +1,125 @@
1
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ #
17
+ # Implementation of the paper:
18
+ #
19
+ # - Brette, R., Rudolph, M., Carnevale, T., Hines, M., Beeman, D., Bower, J. M., et al. (2007),
20
+ # Simulation of networks of spiking neurons: a review of tools and strategies., J. Comput. Neurosci., 23, 3, 349–98
21
+ #
22
+ # which is based on the balanced network proposed by:
23
+ #
24
+ # - Vogels, T. P. and Abbott, L. F. (2005), Signal propagation and logic gating in networks of integrate-and-fire neurons., J. Neurosci., 25, 46, 10786–95
25
+ #
26
+ import os
27
+ import sys
28
+
29
+ sys.path.append('../')
30
+ os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.99'
31
+ os.environ['JAX_TRACEBACK_FILTERING'] = 'off'
32
+
33
+
34
+ import jax
35
+ import brainunit as u
36
+ import time
37
+ import brainstate as bst
38
+
39
+
40
+ class EINet(bst.nn.DynamicsGroup):
41
+ def __init__(self, scale):
42
+ super().__init__()
43
+ self.n_exc = int(3200 * scale)
44
+ self.n_inh = int(800 * scale)
45
+ self.num = self.n_exc + self.n_inh
46
+ self.N = bst.nn.LIFRef(self.num, V_rest=-60. * u.mV, V_th=-50. * u.mV, V_reset=-60. * u.mV,
47
+ tau=20. * u.ms, tau_ref=5. * u.ms,
48
+ V_initializer=bst.init.Normal(-55., 2., unit=u.mV))
49
+ self.E = bst.nn.AlignPostProj(
50
+ comm=bst.event.FixedProb(self.n_exc, self.num, prob=80 / self.num, weight=0.6 * u.mS),
51
+ syn=bst.nn.Expon.desc(self.num, tau=5. * u.ms),
52
+ out=bst.nn.COBA.desc(E=0. * u.mV),
53
+ post=self.N
54
+ )
55
+ self.I = bst.nn.AlignPostProj(
56
+ comm=bst.event.FixedProb(self.n_inh, self.num, prob=80 / self.num, weight=6.7 * u.mS),
57
+ syn=bst.nn.Expon.desc(self.num, tau=10. * u.ms),
58
+ out=bst.nn.COBA.desc(E=-80. * u.mV),
59
+ post=self.N
60
+ )
61
+
62
+ def init_state(self, *args, **kwargs):
63
+ self.rate = bst.ShortTermState(u.math.zeros(self.num))
64
+
65
+ def update(self, t, inp):
66
+ with bst.environ.context(t=t):
67
+ spk = self.N.get_spike() != 0.
68
+ self.E(spk[:self.n_exc])
69
+ self.I(spk[self.n_exc:])
70
+ self.N(inp)
71
+ self.rate.value += self.N.get_spike()
72
+
73
+
74
+ @bst.compile.jit(static_argnums=0)
75
+ def run(scale: float):
76
+ # network
77
+ net = EINet(scale)
78
+ bst.nn.init_all_states(net)
79
+
80
+ duration = 1e4 * u.ms
81
+ # simulation
82
+ with bst.environ.context(dt=0.1 * u.ms):
83
+ times = u.math.arange(0. * u.ms, duration, bst.environ.get_dt())
84
+ bst.compile.for_loop(lambda t: net.update(t, 20. * u.mA), times)
85
+
86
+ return net.num, net.rate.value.sum() / net.num / duration.to_decimal(u.second)
87
+
88
+
89
+ for s in [1, 2, 4, 6, 8, 10, 20, 40, 60, 80, 100]:
90
+ jax.block_until_ready(run(s))
91
+
92
+ t0 = time.time()
93
+ n, rate = jax.block_until_ready(run(s))
94
+ t1 = time.time()
95
+ print(f'scale={s}, size={n}, time = {t1 - t0} s, firing rate = {rate} Hz')
96
+
97
+
98
+ # A6000 NVIDIA GPU
99
+
100
+ # scale=1, size=4000, time = 2.659956455230713 s, firing rate = 50.62445068359375 Hz
101
+ # scale=2, size=8000, time = 2.7318649291992188 s, firing rate = 50.613040924072266 Hz
102
+ # scale=4, size=16000, time = 2.807222604751587 s, firing rate = 50.60573959350586 Hz
103
+ # scale=6, size=24000, time = 3.026782512664795 s, firing rate = 50.60918045043945 Hz
104
+ # scale=8, size=32000, time = 3.1258811950683594 s, firing rate = 50.607574462890625 Hz
105
+ # scale=10, size=40000, time = 3.172346353530884 s, firing rate = 50.60942840576172 Hz
106
+ # scale=20, size=80000, time = 3.751189947128296 s, firing rate = 50.612369537353516 Hz
107
+ # scale=40, size=160000, time = 5.0217814445495605 s, firing rate = 50.617958068847656 Hz
108
+ # scale=60, size=240000, time = 7.002646207809448 s, firing rate = 50.61948776245117 Hz
109
+ # scale=80, size=320000, time = 9.384576320648193 s, firing rate = 50.618499755859375 Hz
110
+ # scale=100, size=400000, time = 11.69654369354248 s, firing rate = 50.61605453491211 Hz
111
+
112
+
113
+ # AMD Ryzen 7 7840HS
114
+
115
+ # scale=1, size=4000, time = 4.436027526855469 s, firing rate = 50.6119270324707 Hz
116
+ # scale=2, size=8000, time = 8.349745273590088 s, firing rate = 50.612266540527344 Hz
117
+ # scale=4, size=16000, time = 16.39163303375244 s, firing rate = 50.61349105834961 Hz
118
+ # scale=6, size=24000, time = 15.725558042526245 s, firing rate = 50.6125602722168 Hz
119
+ # scale=8, size=32000, time = 21.31995177268982 s, firing rate = 50.61244583129883 Hz
120
+ # scale=10, size=40000, time = 27.811061143875122 s, firing rate = 50.61423873901367 Hz
121
+ # scale=20, size=80000, time = 45.54235219955444 s, firing rate = 50.61320877075195 Hz
122
+ # scale=40, size=160000, time = 82.22228026390076 s, firing rate = 50.61309814453125 Hz
123
+ # scale=60, size=240000, time = 125.44037556648254 s, firing rate = 50.613094329833984 Hz
124
+ # scale=80, size=320000, time = 171.20458459854126 s, firing rate = 50.613365173339844 Hz
125
+ # scale=100, size=400000, time = 215.4547393321991 s, firing rate = 50.6129150390625 Hz
benchmark/CUBA_2005.py ADDED
@@ -0,0 +1,149 @@
1
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ #
17
+ # Implementation of the paper:
18
+ #
19
+ # - Brette, R., Rudolph, M., Carnevale, T., Hines, M., Beeman, D., Bower, J. M., et al. (2007),
20
+ # Simulation of networks of spiking neurons: a review of tools and strategies., J. Comput. Neurosci., 23, 3, 349–98
21
+ #
22
+ # which is based on the balanced network proposed by:
23
+ #
24
+ # - Vogels, T. P. and Abbott, L. F. (2005), Signal propagation and logic gating in networks of integrate-and-fire neurons., J. Neurosci., 25, 46, 10786–95
25
+ #
26
+
27
+ import os
28
+ import sys
29
+
30
+ sys.path.append('../')
31
+ os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.99'
32
+ os.environ['JAX_TRACEBACK_FILTERING'] = 'off'
33
+
34
+
35
+ import jax
36
+ import time
37
+
38
+ import brainunit as u
39
+
40
+ import brainstate as bst
41
+
42
+
43
+
44
+ class FixedProb(bst.nn.Module):
45
+ def __init__(self, n_pre, n_post, prob, weight):
46
+ super().__init__()
47
+ self.prob = prob
48
+ self.weight = weight
49
+ self.n_pre = n_pre
50
+ self.n_post = n_post
51
+
52
+ self.mask = bst.random.rand(n_pre, n_post) < prob
53
+
54
+ def update(self, x):
55
+ return (x @ self.mask) * self.weight
56
+
57
+
58
+ class EINet(bst.nn.DynamicsGroup):
59
+ def __init__(self, scale=1.0):
60
+ super().__init__()
61
+ self.n_exc = int(3200 * scale)
62
+ self.n_inh = int(800 * scale)
63
+ self.num = self.n_exc + self.n_inh
64
+ self.N = bst.nn.LIFRef(
65
+ self.num, V_rest=-49. * u.mV, V_th=-50. * u.mV, V_reset=-60. * u.mV,
66
+ tau=20. * u.ms, tau_ref=5. * u.ms,
67
+ V_initializer=bst.init.Normal(-55., 2., unit=u.mV)
68
+ )
69
+ self.E = bst.nn.AlignPostProj(
70
+ comm=bst.event.FixedProb(self.n_exc, self.num, prob=80 / self.num, weight=1.62 * u.mS),
71
+ # comm=FixedProb(self.n_exc, self.num, prob=80 / self.num, weight=1.62 * u.mS),
72
+ syn=bst.nn.Expon.desc(self.num, tau=5. * u.ms),
73
+ out=bst.nn.CUBA.desc(scale=u.volt),
74
+ post=self.N
75
+ )
76
+ self.I = bst.nn.AlignPostProj(
77
+ comm=bst.event.FixedProb(self.n_inh, self.num, prob=80 / self.num, weight=-9.0 * u.mS),
78
+ # comm=FixedProb(self.n_inh, self.num, prob=80 / self.num, weight=-9.0 * u.mS),
79
+ syn=bst.nn.Expon.desc(self.num, tau=10. * u.ms),
80
+ out=bst.nn.CUBA.desc(scale=u.volt),
81
+ post=self.N
82
+ )
83
+
84
+ def init_state(self, *args, **kwargs):
85
+ self.rate = bst.ShortTermState(u.math.zeros(self.num))
86
+
87
+ def update(self, t, inp):
88
+ with bst.environ.context(t=t):
89
+ spk = self.N.get_spike()
90
+ self.E(spk[:self.n_exc])
91
+ self.I(spk[self.n_exc:])
92
+ self.N(inp)
93
+ self.rate.value += self.N.get_spike()
94
+
95
+
96
+ @bst.compile.jit(static_argnums=0)
97
+ def run(scale: float):
98
+ # network
99
+ net = EINet(scale)
100
+ bst.nn.init_all_states(net)
101
+
102
+ duration = 1e4 * u.ms
103
+ # simulation
104
+ with bst.environ.context(dt=0.1 * u.ms):
105
+ times = u.math.arange(0. * u.ms, duration, bst.environ.get_dt())
106
+ bst.compile.for_loop(lambda t: net.update(t, 20. * u.mA), times,
107
+ # pbar=bst.compile.ProgressBar(100)
108
+ )
109
+
110
+ return net.num, net.rate.value.sum() / net.num / duration.to_decimal(u.second)
111
+
112
+
113
+ for s in [1, 2, 4, 6, 8, 10, 20, 40, 60, 80, 100]:
114
+ jax.block_until_ready(run(s))
115
+
116
+ t0 = time.time()
117
+ n, rate = jax.block_until_ready(run(s))
118
+ t1 = time.time()
119
+ print(f'scale={s}, size={n}, time = {t1 - t0} s, firing rate = {rate} Hz')
120
+
121
+
122
+ # A6000 NVIDIA GPU
123
+
124
+ # scale=1, size=4000, time = 2.6354849338531494 s, firing rate = 24.982027053833008 Hz
125
+ # scale=2, size=8000, time = 2.6781561374664307 s, firing rate = 23.719463348388672 Hz
126
+ # scale=4, size=16000, time = 2.7448785305023193 s, firing rate = 24.592931747436523 Hz
127
+ # scale=6, size=24000, time = 2.8237478733062744 s, firing rate = 24.159996032714844 Hz
128
+ # scale=8, size=32000, time = 2.9344418048858643 s, firing rate = 24.956790924072266 Hz
129
+ # scale=10, size=40000, time = 3.042517900466919 s, firing rate = 23.644424438476562 Hz
130
+ # scale=20, size=80000, time = 3.6727631092071533 s, firing rate = 24.226743698120117 Hz
131
+ # scale=40, size=160000, time = 4.857396602630615 s, firing rate = 24.329742431640625 Hz
132
+ # scale=60, size=240000, time = 6.812030792236328 s, firing rate = 24.370006561279297 Hz
133
+ # scale=80, size=320000, time = 9.227966547012329 s, firing rate = 24.41067886352539 Hz
134
+ # scale=100, size=400000, time = 11.405697584152222 s, firing rate = 24.32524871826172 Hz
135
+
136
+
137
+ # AMD Ryzen 7 7840HS
138
+
139
+ # scale=1, size=4000, time = 1.1661601066589355 s, firing rate = 22.438201904296875 Hz
140
+ # scale=2, size=8000, time = 3.3255884647369385 s, firing rate = 23.868364334106445 Hz
141
+ # scale=4, size=16000, time = 6.950139999389648 s, firing rate = 24.21693229675293 Hz
142
+ # scale=6, size=24000, time = 10.011993169784546 s, firing rate = 24.240270614624023 Hz
143
+ # scale=8, size=32000, time = 13.027734518051147 s, firing rate = 24.753198623657227 Hz
144
+ # scale=10, size=40000, time = 16.449942350387573 s, firing rate = 24.7176570892334 Hz
145
+ # scale=20, size=80000, time = 30.754598140716553 s, firing rate = 24.119956970214844 Hz
146
+ # scale=40, size=160000, time = 63.6387836933136 s, firing rate = 24.72784996032715 Hz
147
+ # scale=60, size=240000, time = 78.58532166481018 s, firing rate = 24.402742385864258 Hz
148
+ # scale=80, size=320000, time = 102.4250214099884 s, firing rate = 24.59092140197754 Hz
149
+ # scale=100, size=400000, time = 145.35173273086548 s, firing rate = 24.33751106262207 Hz
brainstate/__init__.py CHANGED
@@ -14,13 +14,17 @@
14
14
  # ==============================================================================
15
15
 
16
16
  """
17
- A ``State``-based Transformation System for Brain Dynamics Programming
17
+ A ``State``-based Transformation System for Program Compilation and Augmentation
18
18
  """
19
19
 
20
- __version__ = "0.0.2"
20
+ __version__ = "0.1.0"
21
21
 
22
+ from . import augment
23
+ from . import compile
22
24
  from . import environ
25
+ from . import event
23
26
  from . import functional
27
+ from . import graph
24
28
  from . import init
25
29
  from . import mixin
26
30
  from . import nn
@@ -30,17 +34,33 @@ from . import surrogate
30
34
  from . import transform
31
35
  from . import typing
32
36
  from . import util
33
- from ._visualization import *
34
- from ._visualization import __all__ as _visualization_all
35
- from ._module import *
36
- from ._module import __all__ as _module_all
37
37
  from ._state import *
38
38
  from ._state import __all__ as _state_all
39
39
 
40
40
  __all__ = (
41
- ['environ', 'share', 'nn', 'optim', 'random',
42
- 'surrogate', 'functional', 'init',
43
- 'mixin', 'transform', 'util', 'typing'] +
44
- _module_all + _state_all + _visualization_all
41
+ [
42
+ 'augment', 'compile', 'environ', 'event', 'functional',
43
+ 'graph', 'init', 'mixin', 'nn', 'optim', 'random',
44
+ 'surrogate', 'typing', 'util',
45
+ # deprecated
46
+ 'transform',
47
+ ] +
48
+ _state_all
45
49
  )
46
- del _module_all, _state_all, _visualization_all
50
+
51
+ # ----------------------- #
52
+ # deprecations
53
+ # ----------------------- #
54
+
55
+ from ._utils import deprecation_getattr
56
+
57
+ transform._deprecations = dict()
58
+ for key in compile.__all__:
59
+ transform._deprecations[key] = (f'brainstate.transform.{key}', f'brainstate.compile.{key}', getattr(compile, key))
60
+ for key in augment.__all__:
61
+ transform._deprecations[key] = (f'brainstate.transform.{key}', f'brainstate.augment.{key}', getattr(augment, key))
62
+ transform.__getattr__ = deprecation_getattr('brainstate.transform', transform._deprecations)
63
+ del deprecation_getattr
64
+
65
+ # ----------------------- #
66
+ del _state_all