brainstate 0.1.0__py2.py3-none-any.whl → 0.1.0.post20241125__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36)
  1. benchmark/COBA_2005.py +125 -0
  2. benchmark/CUBA_2005.py +149 -0
  3. brainstate/_state.py +1 -1
  4. brainstate/augment/_autograd.py +121 -120
  5. brainstate/augment/_autograd_test.py +97 -0
  6. brainstate/event/__init__.py +10 -8
  7. brainstate/event/_csr_benchmark.py +14 -0
  8. brainstate/event/{_csr.py → _csr_mv.py} +26 -18
  9. brainstate/event/_csr_mv_benchmark.py +14 -0
  10. brainstate/event/_fixedprob_mv.py +708 -0
  11. brainstate/event/_fixedprob_mv_benchmark.py +128 -0
  12. brainstate/event/{_fixed_probability_test.py → _fixedprob_mv_test.py} +13 -10
  13. brainstate/event/_linear_mv.py +359 -0
  14. brainstate/event/_linear_mv_benckmark.py +82 -0
  15. brainstate/event/{_linear_test.py → _linear_mv_test.py} +9 -4
  16. brainstate/event/_xla_custom_op.py +309 -0
  17. brainstate/event/_xla_custom_op_test.py +55 -0
  18. brainstate/nn/_dyn_impl/_dynamics_synapse.py +6 -11
  19. brainstate/nn/_dyn_impl/_rate_rnns.py +1 -1
  20. brainstate/nn/_dynamics/_projection_base.py +1 -1
  21. brainstate/nn/_exp_euler.py +1 -1
  22. brainstate/nn/_interaction/__init__.py +13 -4
  23. brainstate/nn/_interaction/{_connections.py → _conv.py} +0 -227
  24. brainstate/nn/_interaction/{_connections_test.py → _conv_test.py} +0 -15
  25. brainstate/nn/_interaction/_linear.py +582 -0
  26. brainstate/nn/_interaction/_linear_test.py +42 -0
  27. brainstate/optim/_lr_scheduler.py +1 -1
  28. brainstate/optim/_optax_optimizer.py +19 -0
  29. {brainstate-0.1.0.dist-info → brainstate-0.1.0.post20241125.dist-info}/METADATA +2 -2
  30. {brainstate-0.1.0.dist-info → brainstate-0.1.0.post20241125.dist-info}/RECORD +34 -24
  31. {brainstate-0.1.0.dist-info → brainstate-0.1.0.post20241125.dist-info}/top_level.txt +1 -0
  32. brainstate/event/_fixed_probability.py +0 -271
  33. brainstate/event/_linear.py +0 -219
  34. /brainstate/event/{_csr_test.py → _csr_mv_test.py} +0 -0
  35. {brainstate-0.1.0.dist-info → brainstate-0.1.0.post20241125.dist-info}/LICENSE +0 -0
  36. {brainstate-0.1.0.dist-info → brainstate-0.1.0.post20241125.dist-info}/WHEEL +0 -0
benchmark/COBA_2005.py ADDED
@@ -0,0 +1,125 @@
1
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ #
17
+ # Implementation of the paper:
18
+ #
19
+ # - Brette, R., Rudolph, M., Carnevale, T., Hines, M., Beeman, D., Bower, J. M., et al. (2007),
20
+ # Simulation of networks of spiking neurons: a review of tools and strategies., J. Comput. Neurosci., 23, 3, 349–98
21
+ #
22
+ # which is based on the balanced network proposed by:
23
+ #
24
+ # - Vogels, T. P. and Abbott, L. F. (2005), Signal propagation and logic gating in networks of integrate-and-fire neurons., J. Neurosci., 25, 46, 10786–95
25
+ #
26
+ import os
27
+ import sys
28
+
29
+ sys.path.append('../')
30
+ os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.99'
31
+ os.environ['JAX_TRACEBACK_FILTERING'] = 'off'
32
+
33
+
34
+ import jax
35
+ import brainunit as u
36
+ import time
37
+ import brainstate as bst
38
+
39
+
40
+ class EINet(bst.nn.DynamicsGroup):  # COBA E/I network: one LIFRef population, first n_exc neurons excitatory, last n_inh inhibitory
41
+ def __init__(self, scale):  # scale multiplies the 3200 E + 800 I baseline population sizes
42
+ super().__init__()
43
+ self.n_exc = int(3200 * scale)  # excitatory population size
44
+ self.n_inh = int(800 * scale)  # inhibitory population size
45
+ self.num = self.n_exc + self.n_inh  # total neurons in the shared LIFRef population
46
+ self.N = bst.nn.LIFRef(self.num, V_rest=-60. * u.mV, V_th=-50. * u.mV, V_reset=-60. * u.mV,
47
+ tau=20. * u.ms, tau_ref=5. * u.ms,
48
+ V_initializer=bst.init.Normal(-55., 2., unit=u.mV))  # initial V drawn from Normal(-55, 2) mV
49
+ self.E = bst.nn.AlignPostProj(  # projection driven by the excitatory slice of the spike vector
50
+ comm=bst.event.FixedProb(self.n_exc, self.num, prob=80 / self.num, weight=0.6 * u.mS),  # prob = 80 / num: E+I in-degree totals ~80 per neuron (assuming prob is the per-pair connection probability)
51
+ syn=bst.nn.Expon.desc(self.num, tau=5. * u.ms),
52
+ out=bst.nn.COBA.desc(E=0. * u.mV),  # conductance-based output; E presumably the 0 mV excitatory reversal potential
53
+ post=self.N
54
+ )
55
+ self.I = bst.nn.AlignPostProj(  # projection driven by the inhibitory slice of the spike vector
56
+ comm=bst.event.FixedProb(self.n_inh, self.num, prob=80 / self.num, weight=6.7 * u.mS),
57
+ syn=bst.nn.Expon.desc(self.num, tau=10. * u.ms),
58
+ out=bst.nn.COBA.desc(E=-80. * u.mV),  # E presumably the -80 mV inhibitory reversal potential
59
+ post=self.N
60
+ )
61
+
62
+ def init_state(self, *args, **kwargs):
63
+ self.rate = bst.ShortTermState(u.math.zeros(self.num))  # per-neuron spike counter, zeroed on (re)initialisation
64
+
65
+ def update(self, t, inp):  # one simulation step at time t with external input inp
66
+ with bst.environ.context(t=t):
67
+ spk = self.N.get_spike() != 0.  # boolean spike mask, read before self.N(inp) below, so it reflects the previous step
68
+ self.E(spk[:self.n_exc])  # spikes of the first n_exc (excitatory) neurons
69
+ self.I(spk[self.n_exc:])  # spikes of the remaining n_inh (inhibitory) neurons
70
+ self.N(inp)  # advance the LIF population
71
+ self.rate.value += self.N.get_spike()  # accumulate this step's spikes; run() turns the total into a rate
72
+
73
+
74
+ @bst.compile.jit(static_argnums=0)  # scale is static: EINet derives population sizes via int(), so it must be a concrete Python value
75
+ def run(scale: float):  # build the network, simulate 10 s, return (network size, mean firing rate in Hz)
76
+ # network
77
+ net = EINet(scale)
78
+ bst.nn.init_all_states(net)
79
+
80
+ duration = 1e4 * u.ms  # 10 s of simulated time
81
+ # simulation
82
+ with bst.environ.context(dt=0.1 * u.ms):
83
+ times = u.math.arange(0. * u.ms, duration, bst.environ.get_dt())  # 1e4 ms / 0.1 ms = 100,000 steps
84
+ bst.compile.for_loop(lambda t: net.update(t, 20. * u.mA), times)  # constant 20 mA input at every step
85
+
86
+ return net.num, net.rate.value.sum() / net.num / duration.to_decimal(u.second)  # total spikes / neurons / seconds = mean rate (Hz)
87
+
88
+
89
+ for s in [1, 2, 4, 6, 8, 10, 20, 40, 60, 80, 100]:  # benchmark each network scale
90
+ jax.block_until_ready(run(s))  # warm-up call: compiles run() for this static scale so compilation is excluded from timing
91
+
92
+ t0 = time.time()  # NOTE(review): time.perf_counter() is the recommended clock for interval timing
93
+ n, rate = jax.block_until_ready(run(s))  # block on the asynchronously dispatched result so the timing covers execution
94
+ t1 = time.time()
95
+ print(f'scale={s}, size={n}, time = {t1 - t0} s, firing rate = {rate} Hz')
96
+
97
+
98
+ # A6000 NVIDIA GPU
99
+
100
+ # scale=1, size=4000, time = 2.659956455230713 s, firing rate = 50.62445068359375 Hz
101
+ # scale=2, size=8000, time = 2.7318649291992188 s, firing rate = 50.613040924072266 Hz
102
+ # scale=4, size=16000, time = 2.807222604751587 s, firing rate = 50.60573959350586 Hz
103
+ # scale=6, size=24000, time = 3.026782512664795 s, firing rate = 50.60918045043945 Hz
104
+ # scale=8, size=32000, time = 3.1258811950683594 s, firing rate = 50.607574462890625 Hz
105
+ # scale=10, size=40000, time = 3.172346353530884 s, firing rate = 50.60942840576172 Hz
106
+ # scale=20, size=80000, time = 3.751189947128296 s, firing rate = 50.612369537353516 Hz
107
+ # scale=40, size=160000, time = 5.0217814445495605 s, firing rate = 50.617958068847656 Hz
108
+ # scale=60, size=240000, time = 7.002646207809448 s, firing rate = 50.61948776245117 Hz
109
+ # scale=80, size=320000, time = 9.384576320648193 s, firing rate = 50.618499755859375 Hz
110
+ # scale=100, size=400000, time = 11.69654369354248 s, firing rate = 50.61605453491211 Hz
111
+
112
+
113
+ # AMD Ryzen 7 7840HS
114
+
115
+ # scale=1, size=4000, time = 4.436027526855469 s, firing rate = 50.6119270324707 Hz
116
+ # scale=2, size=8000, time = 8.349745273590088 s, firing rate = 50.612266540527344 Hz
117
+ # scale=4, size=16000, time = 16.39163303375244 s, firing rate = 50.61349105834961 Hz
118
+ # scale=6, size=24000, time = 15.725558042526245 s, firing rate = 50.6125602722168 Hz
119
+ # scale=8, size=32000, time = 21.31995177268982 s, firing rate = 50.61244583129883 Hz
120
+ # scale=10, size=40000, time = 27.811061143875122 s, firing rate = 50.61423873901367 Hz
121
+ # scale=20, size=80000, time = 45.54235219955444 s, firing rate = 50.61320877075195 Hz
122
+ # scale=40, size=160000, time = 82.22228026390076 s, firing rate = 50.61309814453125 Hz
123
+ # scale=60, size=240000, time = 125.44037556648254 s, firing rate = 50.613094329833984 Hz
124
+ # scale=80, size=320000, time = 171.20458459854126 s, firing rate = 50.613365173339844 Hz
125
+ # scale=100, size=400000, time = 215.4547393321991 s, firing rate = 50.6129150390625 Hz
benchmark/CUBA_2005.py ADDED
@@ -0,0 +1,149 @@
1
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ #
17
+ # Implementation of the paper:
18
+ #
19
+ # - Brette, R., Rudolph, M., Carnevale, T., Hines, M., Beeman, D., Bower, J. M., et al. (2007),
20
+ # Simulation of networks of spiking neurons: a review of tools and strategies., J. Comput. Neurosci., 23, 3, 349–98
21
+ #
22
+ # which is based on the balanced network proposed by:
23
+ #
24
+ # - Vogels, T. P. and Abbott, L. F. (2005), Signal propagation and logic gating in networks of integrate-and-fire neurons., J. Neurosci., 25, 46, 10786–95
25
+ #
26
+
27
+ import os
28
+ import sys
29
+
30
+ sys.path.append('../')
31
+ os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.99'
32
+ os.environ['JAX_TRACEBACK_FILTERING'] = 'off'
33
+
34
+
35
+ import jax
36
+ import time
37
+
38
+ import brainunit as u
39
+
40
+ import brainstate as bst
41
+
42
+
43
+
44
+ class FixedProb(bst.nn.Module):  # dense-mask reference connectivity; only referenced by the commented-out comm= lines in EINet
45
+ def __init__(self, n_pre, n_post, prob, weight):
46
+ super().__init__()
47
+ self.prob = prob
48
+ self.weight = weight
49
+ self.n_pre = n_pre
50
+ self.n_post = n_post
51
+
52
+ self.mask = bst.random.rand(n_pre, n_post) < prob  # boolean (n_pre, n_post) mask; presumably each pair connected with probability prob (assumes rand is uniform on [0, 1))
53
+
54
+ def update(self, x):  # assumes x is a length-n_pre input/spike vector -- TODO confirm
55
+ return (x @ self.mask) * self.weight  # per-postsynaptic sum of connected inputs, scaled by the shared weight
56
+
57
+
58
+ class EINet(bst.nn.DynamicsGroup):  # CUBA E/I network: one LIFRef population with current-based E and I projections
59
+ def __init__(self, scale=1.0):  # scale multiplies the 3200 E + 800 I baseline population sizes
60
+ super().__init__()
61
+ self.n_exc = int(3200 * scale)  # excitatory population size
62
+ self.n_inh = int(800 * scale)  # inhibitory population size
63
+ self.num = self.n_exc + self.n_inh  # total neurons in the shared LIFRef population
64
+ self.N = bst.nn.LIFRef(
65
+ self.num, V_rest=-49. * u.mV, V_th=-50. * u.mV, V_reset=-60. * u.mV,  # NOTE: V_rest (-49 mV) is above V_th (-50 mV); with leaky dynamics this alone drives neurons toward threshold
66
+ tau=20. * u.ms, tau_ref=5. * u.ms,
67
+ V_initializer=bst.init.Normal(-55., 2., unit=u.mV)  # initial V drawn from Normal(-55, 2) mV
68
+ )
69
+ self.E = bst.nn.AlignPostProj(  # projection driven by the excitatory slice of the spike vector
70
+ comm=bst.event.FixedProb(self.n_exc, self.num, prob=80 / self.num, weight=1.62 * u.mS),  # prob = 80 / num: E+I in-degree totals ~80 per neuron (assuming prob is the per-pair connection probability)
71
+ # comm=FixedProb(self.n_exc, self.num, prob=80 / self.num, weight=1.62 * u.mS),
72
+ syn=bst.nn.Expon.desc(self.num, tau=5. * u.ms),
73
+ out=bst.nn.CUBA.desc(scale=u.volt),  # current-based output
74
+ post=self.N
75
+ )
76
+ self.I = bst.nn.AlignPostProj(  # projection driven by the inhibitory slice of the spike vector
77
+ comm=bst.event.FixedProb(self.n_inh, self.num, prob=80 / self.num, weight=-9.0 * u.mS),  # negative weight: with a current-based output the sign makes this projection inhibitory
78
+ # comm=FixedProb(self.n_inh, self.num, prob=80 / self.num, weight=-9.0 * u.mS),
79
+ syn=bst.nn.Expon.desc(self.num, tau=10. * u.ms),
80
+ out=bst.nn.CUBA.desc(scale=u.volt),
81
+ post=self.N
82
+ )
83
+
84
+ def init_state(self, *args, **kwargs):
85
+ self.rate = bst.ShortTermState(u.math.zeros(self.num))  # per-neuron spike counter, zeroed on (re)initialisation
86
+
87
+ def update(self, t, inp):  # one simulation step at time t with external input inp
88
+ with bst.environ.context(t=t):
89
+ spk = self.N.get_spike()  # raw spike values, read before self.N(inp) below (the COBA script converts with != 0. instead)
90
+ self.E(spk[:self.n_exc])  # spikes of the first n_exc (excitatory) neurons
91
+ self.I(spk[self.n_exc:])  # spikes of the remaining n_inh (inhibitory) neurons
92
+ self.N(inp)  # advance the LIF population
93
+ self.rate.value += self.N.get_spike()  # accumulate this step's spikes; run() turns the total into a rate
94
+
95
+
96
+ @bst.compile.jit(static_argnums=0)  # scale is static: EINet derives population sizes via int(), so it must be a concrete Python value
97
+ def run(scale: float):  # build the network, simulate 10 s, return (network size, mean firing rate in Hz)
98
+ # network
99
+ net = EINet(scale)
100
+ bst.nn.init_all_states(net)
101
+
102
+ duration = 1e4 * u.ms  # 10 s of simulated time
103
+ # simulation
104
+ with bst.environ.context(dt=0.1 * u.ms):
105
+ times = u.math.arange(0. * u.ms, duration, bst.environ.get_dt())  # 1e4 ms / 0.1 ms = 100,000 steps
106
+ bst.compile.for_loop(lambda t: net.update(t, 20. * u.mA), times,  # constant 20 mA input at every step
107
+ # pbar=bst.compile.ProgressBar(100)
108
+ )
109
+
110
+ return net.num, net.rate.value.sum() / net.num / duration.to_decimal(u.second)  # total spikes / neurons / seconds = mean rate (Hz)
111
+
112
+
113
+ for s in [1, 2, 4, 6, 8, 10, 20, 40, 60, 80, 100]:  # benchmark each network scale
114
+ jax.block_until_ready(run(s))  # warm-up call: compiles run() for this static scale so compilation is excluded from timing
115
+
116
+ t0 = time.time()  # NOTE(review): time.perf_counter() is the recommended clock for interval timing
117
+ n, rate = jax.block_until_ready(run(s))  # block on the asynchronously dispatched result so the timing covers execution
118
+ t1 = time.time()
119
+ print(f'scale={s}, size={n}, time = {t1 - t0} s, firing rate = {rate} Hz')
120
+
121
+
122
+ # A6000 NVIDIA GPU
123
+
124
+ # scale=1, size=4000, time = 2.6354849338531494 s, firing rate = 24.982027053833008 Hz
125
+ # scale=2, size=8000, time = 2.6781561374664307 s, firing rate = 23.719463348388672 Hz
126
+ # scale=4, size=16000, time = 2.7448785305023193 s, firing rate = 24.592931747436523 Hz
127
+ # scale=6, size=24000, time = 2.8237478733062744 s, firing rate = 24.159996032714844 Hz
128
+ # scale=8, size=32000, time = 2.9344418048858643 s, firing rate = 24.956790924072266 Hz
129
+ # scale=10, size=40000, time = 3.042517900466919 s, firing rate = 23.644424438476562 Hz
130
+ # scale=20, size=80000, time = 3.6727631092071533 s, firing rate = 24.226743698120117 Hz
131
+ # scale=40, size=160000, time = 4.857396602630615 s, firing rate = 24.329742431640625 Hz
132
+ # scale=60, size=240000, time = 6.812030792236328 s, firing rate = 24.370006561279297 Hz
133
+ # scale=80, size=320000, time = 9.227966547012329 s, firing rate = 24.41067886352539 Hz
134
+ # scale=100, size=400000, time = 11.405697584152222 s, firing rate = 24.32524871826172 Hz
135
+
136
+
137
+ # AMD Ryzen 7 7840HS
138
+
139
+ # scale=1, size=4000, time = 1.1661601066589355 s, firing rate = 22.438201904296875 Hz
140
+ # scale=2, size=8000, time = 3.3255884647369385 s, firing rate = 23.868364334106445 Hz
141
+ # scale=4, size=16000, time = 6.950139999389648 s, firing rate = 24.21693229675293 Hz
142
+ # scale=6, size=24000, time = 10.011993169784546 s, firing rate = 24.240270614624023 Hz
143
+ # scale=8, size=32000, time = 13.027734518051147 s, firing rate = 24.753198623657227 Hz
144
+ # scale=10, size=40000, time = 16.449942350387573 s, firing rate = 24.7176570892334 Hz
145
+ # scale=20, size=80000, time = 30.754598140716553 s, firing rate = 24.119956970214844 Hz
146
+ # scale=40, size=160000, time = 63.6387836933136 s, firing rate = 24.72784996032715 Hz
147
+ # scale=60, size=240000, time = 78.58532166481018 s, firing rate = 24.402742385864258 Hz
148
+ # scale=80, size=320000, time = 102.4250214099884 s, firing rate = 24.59092140197754 Hz
149
+ # scale=100, size=400000, time = 145.35173273086548 s, firing rate = 24.33751106262207 Hz
brainstate/_state.py CHANGED
@@ -679,7 +679,7 @@ class StateTraceStack(Generic[A]):
679
679
  """
680
680
  for st, val in zip(self.states, self._original_state_values):
681
681
  # internal use
682
- st._value = val
682
+ st.restore_value(val)
683
683
 
684
684
  def merge(self, *traces) -> 'StateTraceStack':
685
685
  """