brainstate 0.1.0.post20241125__py2.py3-none-any.whl → 0.1.0.post20241209__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/_state.py +15 -1
- brainstate/augment/_mapping_test.py +84 -0
- brainstate/compile/_loop_collect_return.py +5 -1
- brainstate/compile/_make_jaxpr.py +30 -25
- brainstate/compile/_progress_bar.py +30 -12
- brainstate/functional/_activations.py +4 -12
- brainstate/graph/_graph_operation.py +4 -1
- brainstate/nn/_collective_ops.py +18 -2
- brainstate/nn/_dyn_impl/_dynamics_synapse.py +1 -1
- brainstate/nn/_elementwise/_dropout.py +31 -22
- brainstate/nn/_interaction/_normalizations.py +598 -66
- brainstate/util/_tracers.py +0 -7
- {brainstate-0.1.0.post20241125.dist-info → brainstate-0.1.0.post20241209.dist-info}/METADATA +1 -1
- {brainstate-0.1.0.post20241125.dist-info → brainstate-0.1.0.post20241209.dist-info}/RECORD +17 -19
- {brainstate-0.1.0.post20241125.dist-info → brainstate-0.1.0.post20241209.dist-info}/top_level.txt +0 -1
- benchmark/COBA_2005.py +0 -125
- benchmark/CUBA_2005.py +0 -149
- {brainstate-0.1.0.post20241125.dist-info → brainstate-0.1.0.post20241209.dist-info}/LICENSE +0 -0
- {brainstate-0.1.0.post20241125.dist-info → brainstate-0.1.0.post20241209.dist-info}/WHEEL +0 -0
brainstate/util/_tracers.py
CHANGED
@@ -16,7 +16,6 @@ from __future__ import annotations
|
|
16
16
|
|
17
17
|
import jax
|
18
18
|
import jax.core
|
19
|
-
from jax.interpreters import partial_eval as pe
|
20
19
|
|
21
20
|
from ._pretty_repr import PrettyRepr, PrettyType, PrettyAttr
|
22
21
|
|
@@ -25,12 +24,6 @@ __all__ = [
|
|
25
24
|
]
|
26
25
|
|
27
26
|
|
28
|
-
def new_jax_trace():
|
29
|
-
main = jax.core.thread_local_state.trace_state.trace_stack.stack[-1]
|
30
|
-
frame = main.jaxpr_stack[-1]
|
31
|
-
trace = pe.DynamicJaxprTrace(main, jax.core.cur_sublevel())
|
32
|
-
return frame, trace
|
33
|
-
|
34
27
|
|
35
28
|
def current_jax_trace():
|
36
29
|
"""Returns the Jax tracing state."""
|
{brainstate-0.1.0.post20241125.dist-info → brainstate-0.1.0.post20241209.dist-info}/METADATA
RENAMED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: brainstate
|
3
|
-
Version: 0.1.0.
|
3
|
+
Version: 0.1.0.post20241209
|
4
4
|
Summary: A ``State``-based Transformation System for Program Compilation and Augmentation.
|
5
5
|
Home-page: https://github.com/chaobrain/brainstate
|
6
6
|
Author: BrainState Developers
|
@@ -1,7 +1,5 @@
|
|
1
|
-
benchmark/COBA_2005.py,sha256=Q8PsZ0lxu14jsF3bCtlZW35iQB8S2_oFEUYQzK2hPiA,5561
|
2
|
-
benchmark/CUBA_2005.py,sha256=_W94yOMh2ueqblU4ItEPeTLwHF0_lbEWlVNEBy0Tix0,6222
|
3
1
|
brainstate/__init__.py,sha256=r7C3eLTg8LEusoH6PGgBFFt4ZgbketYLoLA0lQhUCsE,2098
|
4
|
-
brainstate/_state.py,sha256=
|
2
|
+
brainstate/_state.py,sha256=4aDpLyHGr1VlPXeLSfM3USQG5K4o7orF7IlaBdYrtfE,29098
|
5
3
|
brainstate/_state_test.py,sha256=1boTp1w8DiCFLsPwNtlLrlIqGRpkasAmLid5bv2fgP4,2223
|
6
4
|
brainstate/_utils.py,sha256=uJ6WWKq3yb05ZdktCQGLWOXsOJveL1H9pR7eev70Jes,1693
|
7
5
|
brainstate/environ.py,sha256=G6r_rqfbofRbjFFalRu_DHaL7ruFTeLRXBQDXM6P-tQ,17477
|
@@ -17,7 +15,7 @@ brainstate/augment/_autograd_test.py,sha256=S2eEgrwTzdSi3u2nKE3u37WSThosLwx1WCP9
|
|
17
15
|
brainstate/augment/_eval_shape.py,sha256=dGlRVHOAZ9LSRZsFi1erxgEWHrnhBO3Kq3WW11-Hvng,3819
|
18
16
|
brainstate/augment/_eval_shape_test.py,sha256=1nnxbU7hPRbZPQWNWbQ518pw-H7FGDKKnQpZGBY9uRI,1390
|
19
17
|
brainstate/augment/_mapping.py,sha256=cpxzVGCEYnP5jPqrowYoPXciw_-QR2F3wggrRj1OCPc,21850
|
20
|
-
brainstate/augment/_mapping_test.py,sha256=
|
18
|
+
brainstate/augment/_mapping_test.py,sha256=TEAecjZmTSDCfxARgrzcDJ2dW1Yz_sCITmFiA9FGrhk,9455
|
21
19
|
brainstate/augment/_random.py,sha256=rkB4w4BkKsz9p8lTk31kVHvlVPJSvtGk8REn936KI_4,3071
|
22
20
|
brainstate/compile/__init__.py,sha256=qZZIYoyEl51IFkFu-Hb-bP3PAEHo94HlTDf57P2ze08,1858
|
23
21
|
brainstate/compile/_ad_checkpoint.py,sha256=5zJ1ENeTU4FzRY_uNpr85NhKfuicMMjcIbhu6-bSM4k,9451
|
@@ -28,13 +26,13 @@ brainstate/compile/_error_if.py,sha256=TFvhqITKkRO9m30GdlUP4eEjJvLWQUhjkujXO9zvr
|
|
28
26
|
brainstate/compile/_error_if_test.py,sha256=SJmAfosVoGd4vhfFtb1IvjeFVW914bfTccCg6DoLWYk,1992
|
29
27
|
brainstate/compile/_jit.py,sha256=bfEszNttEtE6npqHBam1_DBlRa39fE6qP6lGaWw2amA,13750
|
30
28
|
brainstate/compile/_jit_test.py,sha256=zD7kck9SQJGmUDolh9P4luKwQ21fBGje1Z4STTEXIuA,4135
|
31
|
-
brainstate/compile/_loop_collect_return.py,sha256=
|
29
|
+
brainstate/compile/_loop_collect_return.py,sha256=8vDB2l0d4sIn0apspJzkhFhxjsL7reIptDeFRI9b1tc,23002
|
32
30
|
brainstate/compile/_loop_collect_return_test.py,sha256=bA-_11E8A_0jR5umEO3e409y7bb5QYDTgSL-SBaX7kQ,1802
|
33
31
|
brainstate/compile/_loop_no_collection.py,sha256=2rSK20enkBMXPAbsCyb7PCICPNrgaSpl5jfumgWpxA0,7401
|
34
32
|
brainstate/compile/_loop_no_collection_test.py,sha256=oStB1CSG_iLp9sHdXd1hJNFvlxbzjck9Iy4sABoJDj4,1419
|
35
|
-
brainstate/compile/_make_jaxpr.py,sha256=
|
33
|
+
brainstate/compile/_make_jaxpr.py,sha256=S5O9KUB3bsxoKcfptlV0MRfKA__Ija37WxkakIRL3z0,33010
|
36
34
|
brainstate/compile/_make_jaxpr_test.py,sha256=qJUtkyj50JQ6f4UJbOLhvRdkbNn3NSKibFL9jESdQkA,4279
|
37
|
-
brainstate/compile/_progress_bar.py,sha256=
|
35
|
+
brainstate/compile/_progress_bar.py,sha256=FafEbD9KzmhCCizfQoXXLw46asn9_uiuH1U5_DMtSXg,4529
|
38
36
|
brainstate/compile/_unvmap.py,sha256=ewbLLNXiI_dBsEBaVzSS0BEXNol22sd9gMzk606lSkM,4139
|
39
37
|
brainstate/compile/_util.py,sha256=aCvkTV--g4NsqcodTdBAISt4EwgezCbKzNUV58n-Q_Y,6304
|
40
38
|
brainstate/event/__init__.py,sha256=wOBkq7kDg90M8Y9FuoXRlSEuu1ZzbIhCJ1dHeLqN6_Q,1194
|
@@ -52,7 +50,7 @@ brainstate/event/_misc.py,sha256=8IpPooXjF2m0-tuo3pGHqThq2yLSNmYziy_zdurZ3NI,104
|
|
52
50
|
brainstate/event/_xla_custom_op.py,sha256=QB4jz_fUEPF-efJCVKAxwx8U79AqdcKoEg2QrGwot8I,10864
|
53
51
|
brainstate/event/_xla_custom_op_test.py,sha256=rnkGMleXzLfJj4y5QqwfBvCCLTAHe_uabwBDniY-URM,1745
|
54
52
|
brainstate/functional/__init__.py,sha256=j6-3Er4fgqWpvntzYCZVB3e5hoz-Z3aqvapITCuDri0,1107
|
55
|
-
brainstate/functional/_activations.py,sha256=
|
53
|
+
brainstate/functional/_activations.py,sha256=S0Ok7sq5FTbmJWSejpOCHo1jpKX0gYOLy_TO2IUXM8s,21726
|
56
54
|
brainstate/functional/_activations_test.py,sha256=T___RlSrIfXwlkw8dg5A9EZMTZGDzv3a2evUwq_nYFg,13034
|
57
55
|
brainstate/functional/_normalization.py,sha256=i2EV7hSsqcNdcYRX2wAxjq8doHwyN9eNJTGTaPt03xE,2605
|
58
56
|
brainstate/functional/_others.py,sha256=_u_Ys-LiLzDAP4zJggVwaVvirgoS3jvhXMREoS6JOkM,1737
|
@@ -63,7 +61,7 @@ brainstate/graph/_graph_context_test.py,sha256=IYpjqbXwSFF65XL0ZbdPeC1jYyEHLpQVr
|
|
63
61
|
brainstate/graph/_graph_convert.py,sha256=llSREtGQrIggkD0wmxUbYKuSveLW4ihDZME6Ab-mRTQ,9147
|
64
62
|
brainstate/graph/_graph_node.py,sha256=BTuVlGgA2b82zNudjsN88QXuxfDcMvU2-kB64AkdQnY,8993
|
65
63
|
brainstate/graph/_graph_node_test.py,sha256=BFGfdzZFDHI0XK7hHotSVWKt3em1taGvn8FHF9NCXx8,2702
|
66
|
-
brainstate/graph/_graph_operation.py,sha256=
|
64
|
+
brainstate/graph/_graph_operation.py,sha256=PupZeFWBR-OHbhdJcoqlvy2YqoIS9Ze4q0tz8HRy4f4,64166
|
67
65
|
brainstate/graph/_graph_operation_test.py,sha256=ADyyuMk2xidEkkFNpGvUbvEtRmUj-tqOI4cF3eRuakM,24678
|
68
66
|
brainstate/init/__init__.py,sha256=R1dHgub47o-WJM9QkFLc7x_Q7GsyaKKDtrRHTFPpC5g,1097
|
69
67
|
brainstate/init/_base.py,sha256=B_NLS9aKNrvuj5NAlSgBbQTVev7IRvzcx8vH0J-Gq2w,1671
|
@@ -73,7 +71,7 @@ brainstate/init/_random_inits_test.py,sha256=lBL2RQdBSZ88Zqz4IMdbHJMvDi7ooZq6caC
|
|
73
71
|
brainstate/init/_regular_inits.py,sha256=DmVMajugfyYFNUMzgFdDKMvbBu9hMWxkfDd-50uhoLg,3187
|
74
72
|
brainstate/init/_regular_inits_test.py,sha256=tJl4aOkclllJIfKzJTbc0cfYCw2SoBsx8_G123RnqbU,1842
|
75
73
|
brainstate/nn/__init__.py,sha256=rxURT8J1XfBn3Vh3Dx_WzVADWn9zVriIty5KZEG-x6o,1622
|
76
|
-
brainstate/nn/_collective_ops.py,sha256=
|
74
|
+
brainstate/nn/_collective_ops.py,sha256=sSjIIs1MvZA30XFFmK7iL1D_sCeh7hFd3PanCH6kgZo,6779
|
77
75
|
brainstate/nn/_exp_euler.py,sha256=yjkfSllFxGWKEAlHo5AzBizzkFj6FEVDKmFV6E2g214,3521
|
78
76
|
brainstate/nn/_exp_euler_test.py,sha256=clwRD8QR71k1jn6NrACMDEUcFMh0J9RTosoPnlYWUkw,1242
|
79
77
|
brainstate/nn/_module.py,sha256=HDLPvLfB7jat2VT3gBu0MxA7vfzK7xgowemitHX8Cgo,10835
|
@@ -82,7 +80,7 @@ brainstate/nn/metrics.py,sha256=iupHjSRTHYY-HmEPBC4tXWrZfF4zh1ek2NwSAA0gnwE,1473
|
|
82
80
|
brainstate/nn/_dyn_impl/__init__.py,sha256=Oazar7h89dp1WA2Vx4Tj7gCBhxJKH4LAUEABkBEG7vU,1462
|
83
81
|
brainstate/nn/_dyn_impl/_dynamics_neuron.py,sha256=cTbIn41EPYG0h3ICzKBXxpgB6wwA2K8k5FAcf3Pa5N8,10927
|
84
82
|
brainstate/nn/_dyn_impl/_dynamics_neuron_test.py,sha256=Tfzrzu7udGrLJGnqItiLWe5WT0dgduvYOgzGCnaPJQg,6317
|
85
|
-
brainstate/nn/_dyn_impl/_dynamics_synapse.py,sha256=
|
83
|
+
brainstate/nn/_dyn_impl/_dynamics_synapse.py,sha256=IHy6IsGjWpKZ8NLq4X7PaRwx3tpO2HRZNppCWM2fe4I,11862
|
86
84
|
brainstate/nn/_dyn_impl/_dynamics_synapse_test.py,sha256=t5i-HV0ii9sUNzWTEv04o26QVtQ-mCdMJcFq2MD755A,4981
|
87
85
|
brainstate/nn/_dyn_impl/_inputs.py,sha256=6eZKnkmrM0Gog2fpSKjSnwnQvhbFYhG4q9Vuo-GH2LI,5050
|
88
86
|
brainstate/nn/_dyn_impl/_projection_alignpost.py,sha256=PNC1Tzx_SF2DHAHeJCufXzO_Q4qLoBpWABI45B3GRuc,876
|
@@ -98,7 +96,7 @@ brainstate/nn/_dynamics/_state_delay.py,sha256=nZYGmVKmQvAQu-W4YOUFH1gnr-ZS3rg_G
|
|
98
96
|
brainstate/nn/_dynamics/_synouts.py,sha256=9TGAc-nVa50th7KKn4oKLbro-4W4rwxYvp-eu7ksAIE,4491
|
99
97
|
brainstate/nn/_dynamics/_synouts_test.py,sha256=V_jDswRN4VvEXD-2yJO3VA1TALgX0HK6oPBQiUntOWc,2266
|
100
98
|
brainstate/nn/_elementwise/__init__.py,sha256=PK8oq1K_EG2941AiUyLxCWoRdWvMO3yt8ZJbw3Lkhu8,935
|
101
|
-
brainstate/nn/_elementwise/_dropout.py,sha256=
|
99
|
+
brainstate/nn/_elementwise/_dropout.py,sha256=0Ebo-2y1VswvBqZ7sCA0SEUm37y49EUsef8oiSFpYGk,17759
|
102
100
|
brainstate/nn/_elementwise/_dropout_test.py,sha256=Qn7xqZOyZMPCGF6tFjTiPId0yELOXjSsW5-hgihP3fE,4383
|
103
101
|
brainstate/nn/_elementwise/_elementwise.py,sha256=om-KpwDTk5yFG5KBYXXHquRLV7s28_FJjk-omvyMyvQ,33342
|
104
102
|
brainstate/nn/_elementwise/_elementwise_test.py,sha256=SZI9jB39sZ5SO1dpWGW-PhodthwN0GU9FY1nqf2fWcs,5341
|
@@ -108,7 +106,7 @@ brainstate/nn/_interaction/_conv_test.py,sha256=fHXRFYnDghFiKre63RqMwIE_gbPKdK34
|
|
108
106
|
brainstate/nn/_interaction/_embedding.py,sha256=iK0I1ExKWFa_QzV9UDGj32Ljsmdr1g_LlAtMcusebxU,2187
|
109
107
|
brainstate/nn/_interaction/_linear.py,sha256=bjiWGJCe81ugQQOykwjWlLW5uhe0CHWwkPA20a4n5YQ,21340
|
110
108
|
brainstate/nn/_interaction/_linear_test.py,sha256=KlvFZA0rpyaspf4LT4K7u-RR5jCEB_q1WReqAw9sFcU,1274
|
111
|
-
brainstate/nn/_interaction/_normalizations.py,sha256=
|
109
|
+
brainstate/nn/_interaction/_normalizations.py,sha256=7YDzkmO_iqd70fH_wawb60Bu8eGOdvZq23emP-b68Hc,37440
|
112
110
|
brainstate/nn/_interaction/_normalizations_test.py,sha256=2p1Jf8nA999VYGWbvOZfKYlKk6UmL0vaEB76xkXxkXw,2438
|
113
111
|
brainstate/nn/_interaction/_poolings.py,sha256=LpwuyeNBVCaVFW7zWc7E-vvlYqx54h46Br5XT6zd_94,47020
|
114
112
|
brainstate/nn/_interaction/_poolings_test.py,sha256=wmd5PngZ3E9tNyF3s0xk-DoDR5yFqpTi9A6nbNoIqn4,7429
|
@@ -136,10 +134,10 @@ brainstate/util/_others.py,sha256=jsPZwP-v_5HRV-LB5F0NUsiqr04y8bmGIsu_JMyVcbQ,14
|
|
136
134
|
brainstate/util/_pretty_repr.py,sha256=NYEBCo2iz9Potx-IR7uZZzt2aLQW_94vH79fGusiC2A,5737
|
137
135
|
brainstate/util/_scaling.py,sha256=pc_eM_SZVwkY65I4tJh1ODiHNCoEhsfFXl2zBK0PLAg,7562
|
138
136
|
brainstate/util/_struct.py,sha256=0exv0oOiSt1hmx20Y4J2-pCGtCTx13WcAlEYSBkyung,17640
|
139
|
-
brainstate/util/_tracers.py,sha256
|
137
|
+
brainstate/util/_tracers.py,sha256=0r5T4nhxMzI79NtqroqitsdMT4YfpgV5RdYJLS5uJ0w,2285
|
140
138
|
brainstate/util/_visualization.py,sha256=n4ZVz10z7VBqA0cKO6vyHwEMprWJgPeEqtITzDMai2Y,1519
|
141
|
-
brainstate-0.1.0.
|
142
|
-
brainstate-0.1.0.
|
143
|
-
brainstate-0.1.0.
|
144
|
-
brainstate-0.1.0.
|
145
|
-
brainstate-0.1.0.
|
139
|
+
brainstate-0.1.0.post20241209.dist-info/LICENSE,sha256=VZe9u1jgUL2eCY6ZPOYgdb8KCblCHt8ECdbtJid6e1s,11550
|
140
|
+
brainstate-0.1.0.post20241209.dist-info/METADATA,sha256=gXsiYWSQqOJ0CWKINESG4sSpnDkcmVYgWJWeEFLTHoA,3401
|
141
|
+
brainstate-0.1.0.post20241209.dist-info/WHEEL,sha256=bb2Ot9scclHKMOLDEHY6B2sicWOgugjFKaJsT7vwMQo,110
|
142
|
+
brainstate-0.1.0.post20241209.dist-info/top_level.txt,sha256=eQbGgKn0ptx7FDWuua0V0wr4K1VHi2iOUCYo3fUQBRA,11
|
143
|
+
brainstate-0.1.0.post20241209.dist-info/RECORD,,
|
benchmark/COBA_2005.py
DELETED
@@ -1,125 +0,0 @@
|
|
1
|
-
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
15
|
-
|
16
|
-
#
|
17
|
-
# Implementation of the paper:
|
18
|
-
#
|
19
|
-
# - Brette, R., Rudolph, M., Carnevale, T., Hines, M., Beeman, D., Bower, J. M., et al. (2007),
|
20
|
-
# Simulation of networks of spiking neurons: a review of tools and strategies., J. Comput. Neurosci., 23, 3, 349–98
|
21
|
-
#
|
22
|
-
# which is based on the balanced network proposed by:
|
23
|
-
#
|
24
|
-
# - Vogels, T. P. and Abbott, L. F. (2005), Signal propagation and logic gating in networks of integrate-and-fire neurons., J. Neurosci., 25, 46, 10786–95
|
25
|
-
#
|
26
|
-
import os
|
27
|
-
import sys
|
28
|
-
|
29
|
-
sys.path.append('../')
|
30
|
-
os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.99'
|
31
|
-
os.environ['JAX_TRACEBACK_FILTERING'] = 'off'
|
32
|
-
|
33
|
-
|
34
|
-
import jax
|
35
|
-
import brainunit as u
|
36
|
-
import time
|
37
|
-
import brainstate as bst
|
38
|
-
|
39
|
-
|
40
|
-
class EINet(bst.nn.DynamicsGroup):
|
41
|
-
def __init__(self, scale):
|
42
|
-
super().__init__()
|
43
|
-
self.n_exc = int(3200 * scale)
|
44
|
-
self.n_inh = int(800 * scale)
|
45
|
-
self.num = self.n_exc + self.n_inh
|
46
|
-
self.N = bst.nn.LIFRef(self.num, V_rest=-60. * u.mV, V_th=-50. * u.mV, V_reset=-60. * u.mV,
|
47
|
-
tau=20. * u.ms, tau_ref=5. * u.ms,
|
48
|
-
V_initializer=bst.init.Normal(-55., 2., unit=u.mV))
|
49
|
-
self.E = bst.nn.AlignPostProj(
|
50
|
-
comm=bst.event.FixedProb(self.n_exc, self.num, prob=80 / self.num, weight=0.6 * u.mS),
|
51
|
-
syn=bst.nn.Expon.desc(self.num, tau=5. * u.ms),
|
52
|
-
out=bst.nn.COBA.desc(E=0. * u.mV),
|
53
|
-
post=self.N
|
54
|
-
)
|
55
|
-
self.I = bst.nn.AlignPostProj(
|
56
|
-
comm=bst.event.FixedProb(self.n_inh, self.num, prob=80 / self.num, weight=6.7 * u.mS),
|
57
|
-
syn=bst.nn.Expon.desc(self.num, tau=10. * u.ms),
|
58
|
-
out=bst.nn.COBA.desc(E=-80. * u.mV),
|
59
|
-
post=self.N
|
60
|
-
)
|
61
|
-
|
62
|
-
def init_state(self, *args, **kwargs):
|
63
|
-
self.rate = bst.ShortTermState(u.math.zeros(self.num))
|
64
|
-
|
65
|
-
def update(self, t, inp):
|
66
|
-
with bst.environ.context(t=t):
|
67
|
-
spk = self.N.get_spike() != 0.
|
68
|
-
self.E(spk[:self.n_exc])
|
69
|
-
self.I(spk[self.n_exc:])
|
70
|
-
self.N(inp)
|
71
|
-
self.rate.value += self.N.get_spike()
|
72
|
-
|
73
|
-
|
74
|
-
@bst.compile.jit(static_argnums=0)
|
75
|
-
def run(scale: float):
|
76
|
-
# network
|
77
|
-
net = EINet(scale)
|
78
|
-
bst.nn.init_all_states(net)
|
79
|
-
|
80
|
-
duration = 1e4 * u.ms
|
81
|
-
# simulation
|
82
|
-
with bst.environ.context(dt=0.1 * u.ms):
|
83
|
-
times = u.math.arange(0. * u.ms, duration, bst.environ.get_dt())
|
84
|
-
bst.compile.for_loop(lambda t: net.update(t, 20. * u.mA), times)
|
85
|
-
|
86
|
-
return net.num, net.rate.value.sum() / net.num / duration.to_decimal(u.second)
|
87
|
-
|
88
|
-
|
89
|
-
for s in [1, 2, 4, 6, 8, 10, 20, 40, 60, 80, 100]:
|
90
|
-
jax.block_until_ready(run(s))
|
91
|
-
|
92
|
-
t0 = time.time()
|
93
|
-
n, rate = jax.block_until_ready(run(s))
|
94
|
-
t1 = time.time()
|
95
|
-
print(f'scale={s}, size={n}, time = {t1 - t0} s, firing rate = {rate} Hz')
|
96
|
-
|
97
|
-
|
98
|
-
# A6000 NVIDIA GPU
|
99
|
-
|
100
|
-
# scale=1, size=4000, time = 2.659956455230713 s, firing rate = 50.62445068359375 Hz
|
101
|
-
# scale=2, size=8000, time = 2.7318649291992188 s, firing rate = 50.613040924072266 Hz
|
102
|
-
# scale=4, size=16000, time = 2.807222604751587 s, firing rate = 50.60573959350586 Hz
|
103
|
-
# scale=6, size=24000, time = 3.026782512664795 s, firing rate = 50.60918045043945 Hz
|
104
|
-
# scale=8, size=32000, time = 3.1258811950683594 s, firing rate = 50.607574462890625 Hz
|
105
|
-
# scale=10, size=40000, time = 3.172346353530884 s, firing rate = 50.60942840576172 Hz
|
106
|
-
# scale=20, size=80000, time = 3.751189947128296 s, firing rate = 50.612369537353516 Hz
|
107
|
-
# scale=40, size=160000, time = 5.0217814445495605 s, firing rate = 50.617958068847656 Hz
|
108
|
-
# scale=60, size=240000, time = 7.002646207809448 s, firing rate = 50.61948776245117 Hz
|
109
|
-
# scale=80, size=320000, time = 9.384576320648193 s, firing rate = 50.618499755859375 Hz
|
110
|
-
# scale=100, size=400000, time = 11.69654369354248 s, firing rate = 50.61605453491211 Hz
|
111
|
-
|
112
|
-
|
113
|
-
# AMD Ryzen 7 7840HS
|
114
|
-
|
115
|
-
# scale=1, size=4000, time = 4.436027526855469 s, firing rate = 50.6119270324707 Hz
|
116
|
-
# scale=2, size=8000, time = 8.349745273590088 s, firing rate = 50.612266540527344 Hz
|
117
|
-
# scale=4, size=16000, time = 16.39163303375244 s, firing rate = 50.61349105834961 Hz
|
118
|
-
# scale=6, size=24000, time = 15.725558042526245 s, firing rate = 50.6125602722168 Hz
|
119
|
-
# scale=8, size=32000, time = 21.31995177268982 s, firing rate = 50.61244583129883 Hz
|
120
|
-
# scale=10, size=40000, time = 27.811061143875122 s, firing rate = 50.61423873901367 Hz
|
121
|
-
# scale=20, size=80000, time = 45.54235219955444 s, firing rate = 50.61320877075195 Hz
|
122
|
-
# scale=40, size=160000, time = 82.22228026390076 s, firing rate = 50.61309814453125 Hz
|
123
|
-
# scale=60, size=240000, time = 125.44037556648254 s, firing rate = 50.613094329833984 Hz
|
124
|
-
# scale=80, size=320000, time = 171.20458459854126 s, firing rate = 50.613365173339844 Hz
|
125
|
-
# scale=100, size=400000, time = 215.4547393321991 s, firing rate = 50.6129150390625 Hz
|
benchmark/CUBA_2005.py
DELETED
@@ -1,149 +0,0 @@
|
|
1
|
-
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
15
|
-
|
16
|
-
#
|
17
|
-
# Implementation of the paper:
|
18
|
-
#
|
19
|
-
# - Brette, R., Rudolph, M., Carnevale, T., Hines, M., Beeman, D., Bower, J. M., et al. (2007),
|
20
|
-
# Simulation of networks of spiking neurons: a review of tools and strategies., J. Comput. Neurosci., 23, 3, 349–98
|
21
|
-
#
|
22
|
-
# which is based on the balanced network proposed by:
|
23
|
-
#
|
24
|
-
# - Vogels, T. P. and Abbott, L. F. (2005), Signal propagation and logic gating in networks of integrate-and-fire neurons., J. Neurosci., 25, 46, 10786–95
|
25
|
-
#
|
26
|
-
|
27
|
-
import os
|
28
|
-
import sys
|
29
|
-
|
30
|
-
sys.path.append('../')
|
31
|
-
os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.99'
|
32
|
-
os.environ['JAX_TRACEBACK_FILTERING'] = 'off'
|
33
|
-
|
34
|
-
|
35
|
-
import jax
|
36
|
-
import time
|
37
|
-
|
38
|
-
import brainunit as u
|
39
|
-
|
40
|
-
import brainstate as bst
|
41
|
-
|
42
|
-
|
43
|
-
|
44
|
-
class FixedProb(bst.nn.Module):
|
45
|
-
def __init__(self, n_pre, n_post, prob, weight):
|
46
|
-
super().__init__()
|
47
|
-
self.prob = prob
|
48
|
-
self.weight = weight
|
49
|
-
self.n_pre = n_pre
|
50
|
-
self.n_post = n_post
|
51
|
-
|
52
|
-
self.mask = bst.random.rand(n_pre, n_post) < prob
|
53
|
-
|
54
|
-
def update(self, x):
|
55
|
-
return (x @ self.mask) * self.weight
|
56
|
-
|
57
|
-
|
58
|
-
class EINet(bst.nn.DynamicsGroup):
|
59
|
-
def __init__(self, scale=1.0):
|
60
|
-
super().__init__()
|
61
|
-
self.n_exc = int(3200 * scale)
|
62
|
-
self.n_inh = int(800 * scale)
|
63
|
-
self.num = self.n_exc + self.n_inh
|
64
|
-
self.N = bst.nn.LIFRef(
|
65
|
-
self.num, V_rest=-49. * u.mV, V_th=-50. * u.mV, V_reset=-60. * u.mV,
|
66
|
-
tau=20. * u.ms, tau_ref=5. * u.ms,
|
67
|
-
V_initializer=bst.init.Normal(-55., 2., unit=u.mV)
|
68
|
-
)
|
69
|
-
self.E = bst.nn.AlignPostProj(
|
70
|
-
comm=bst.event.FixedProb(self.n_exc, self.num, prob=80 / self.num, weight=1.62 * u.mS),
|
71
|
-
# comm=FixedProb(self.n_exc, self.num, prob=80 / self.num, weight=1.62 * u.mS),
|
72
|
-
syn=bst.nn.Expon.desc(self.num, tau=5. * u.ms),
|
73
|
-
out=bst.nn.CUBA.desc(scale=u.volt),
|
74
|
-
post=self.N
|
75
|
-
)
|
76
|
-
self.I = bst.nn.AlignPostProj(
|
77
|
-
comm=bst.event.FixedProb(self.n_inh, self.num, prob=80 / self.num, weight=-9.0 * u.mS),
|
78
|
-
# comm=FixedProb(self.n_inh, self.num, prob=80 / self.num, weight=-9.0 * u.mS),
|
79
|
-
syn=bst.nn.Expon.desc(self.num, tau=10. * u.ms),
|
80
|
-
out=bst.nn.CUBA.desc(scale=u.volt),
|
81
|
-
post=self.N
|
82
|
-
)
|
83
|
-
|
84
|
-
def init_state(self, *args, **kwargs):
|
85
|
-
self.rate = bst.ShortTermState(u.math.zeros(self.num))
|
86
|
-
|
87
|
-
def update(self, t, inp):
|
88
|
-
with bst.environ.context(t=t):
|
89
|
-
spk = self.N.get_spike()
|
90
|
-
self.E(spk[:self.n_exc])
|
91
|
-
self.I(spk[self.n_exc:])
|
92
|
-
self.N(inp)
|
93
|
-
self.rate.value += self.N.get_spike()
|
94
|
-
|
95
|
-
|
96
|
-
@bst.compile.jit(static_argnums=0)
|
97
|
-
def run(scale: float):
|
98
|
-
# network
|
99
|
-
net = EINet(scale)
|
100
|
-
bst.nn.init_all_states(net)
|
101
|
-
|
102
|
-
duration = 1e4 * u.ms
|
103
|
-
# simulation
|
104
|
-
with bst.environ.context(dt=0.1 * u.ms):
|
105
|
-
times = u.math.arange(0. * u.ms, duration, bst.environ.get_dt())
|
106
|
-
bst.compile.for_loop(lambda t: net.update(t, 20. * u.mA), times,
|
107
|
-
# pbar=bst.compile.ProgressBar(100)
|
108
|
-
)
|
109
|
-
|
110
|
-
return net.num, net.rate.value.sum() / net.num / duration.to_decimal(u.second)
|
111
|
-
|
112
|
-
|
113
|
-
for s in [1, 2, 4, 6, 8, 10, 20, 40, 60, 80, 100]:
|
114
|
-
jax.block_until_ready(run(s))
|
115
|
-
|
116
|
-
t0 = time.time()
|
117
|
-
n, rate = jax.block_until_ready(run(s))
|
118
|
-
t1 = time.time()
|
119
|
-
print(f'scale={s}, size={n}, time = {t1 - t0} s, firing rate = {rate} Hz')
|
120
|
-
|
121
|
-
|
122
|
-
# A6000 NVIDIA GPU
|
123
|
-
|
124
|
-
# scale=1, size=4000, time = 2.6354849338531494 s, firing rate = 24.982027053833008 Hz
|
125
|
-
# scale=2, size=8000, time = 2.6781561374664307 s, firing rate = 23.719463348388672 Hz
|
126
|
-
# scale=4, size=16000, time = 2.7448785305023193 s, firing rate = 24.592931747436523 Hz
|
127
|
-
# scale=6, size=24000, time = 2.8237478733062744 s, firing rate = 24.159996032714844 Hz
|
128
|
-
# scale=8, size=32000, time = 2.9344418048858643 s, firing rate = 24.956790924072266 Hz
|
129
|
-
# scale=10, size=40000, time = 3.042517900466919 s, firing rate = 23.644424438476562 Hz
|
130
|
-
# scale=20, size=80000, time = 3.6727631092071533 s, firing rate = 24.226743698120117 Hz
|
131
|
-
# scale=40, size=160000, time = 4.857396602630615 s, firing rate = 24.329742431640625 Hz
|
132
|
-
# scale=60, size=240000, time = 6.812030792236328 s, firing rate = 24.370006561279297 Hz
|
133
|
-
# scale=80, size=320000, time = 9.227966547012329 s, firing rate = 24.41067886352539 Hz
|
134
|
-
# scale=100, size=400000, time = 11.405697584152222 s, firing rate = 24.32524871826172 Hz
|
135
|
-
|
136
|
-
|
137
|
-
# AMD Ryzen 7 7840HS
|
138
|
-
|
139
|
-
# scale=1, size=4000, time = 1.1661601066589355 s, firing rate = 22.438201904296875 Hz
|
140
|
-
# scale=2, size=8000, time = 3.3255884647369385 s, firing rate = 23.868364334106445 Hz
|
141
|
-
# scale=4, size=16000, time = 6.950139999389648 s, firing rate = 24.21693229675293 Hz
|
142
|
-
# scale=6, size=24000, time = 10.011993169784546 s, firing rate = 24.240270614624023 Hz
|
143
|
-
# scale=8, size=32000, time = 13.027734518051147 s, firing rate = 24.753198623657227 Hz
|
144
|
-
# scale=10, size=40000, time = 16.449942350387573 s, firing rate = 24.7176570892334 Hz
|
145
|
-
# scale=20, size=80000, time = 30.754598140716553 s, firing rate = 24.119956970214844 Hz
|
146
|
-
# scale=40, size=160000, time = 63.6387836933136 s, firing rate = 24.72784996032715 Hz
|
147
|
-
# scale=60, size=240000, time = 78.58532166481018 s, firing rate = 24.402742385864258 Hz
|
148
|
-
# scale=80, size=320000, time = 102.4250214099884 s, firing rate = 24.59092140197754 Hz
|
149
|
-
# scale=100, size=400000, time = 145.35173273086548 s, firing rate = 24.33751106262207 Hz
|
File without changes
|
File without changes
|