brainstate 0.1.0__py2.py3-none-any.whl → 0.1.0.post20241125__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36) hide show
  1. benchmark/COBA_2005.py +125 -0
  2. benchmark/CUBA_2005.py +149 -0
  3. brainstate/_state.py +1 -1
  4. brainstate/augment/_autograd.py +121 -120
  5. brainstate/augment/_autograd_test.py +97 -0
  6. brainstate/event/__init__.py +10 -8
  7. brainstate/event/_csr_benchmark.py +14 -0
  8. brainstate/event/{_csr.py → _csr_mv.py} +26 -18
  9. brainstate/event/_csr_mv_benchmark.py +14 -0
  10. brainstate/event/_fixedprob_mv.py +708 -0
  11. brainstate/event/_fixedprob_mv_benchmark.py +128 -0
  12. brainstate/event/{_fixed_probability_test.py → _fixedprob_mv_test.py} +13 -10
  13. brainstate/event/_linear_mv.py +359 -0
  14. brainstate/event/_linear_mv_benckmark.py +82 -0
  15. brainstate/event/{_linear_test.py → _linear_mv_test.py} +9 -4
  16. brainstate/event/_xla_custom_op.py +309 -0
  17. brainstate/event/_xla_custom_op_test.py +55 -0
  18. brainstate/nn/_dyn_impl/_dynamics_synapse.py +6 -11
  19. brainstate/nn/_dyn_impl/_rate_rnns.py +1 -1
  20. brainstate/nn/_dynamics/_projection_base.py +1 -1
  21. brainstate/nn/_exp_euler.py +1 -1
  22. brainstate/nn/_interaction/__init__.py +13 -4
  23. brainstate/nn/_interaction/{_connections.py → _conv.py} +0 -227
  24. brainstate/nn/_interaction/{_connections_test.py → _conv_test.py} +0 -15
  25. brainstate/nn/_interaction/_linear.py +582 -0
  26. brainstate/nn/_interaction/_linear_test.py +42 -0
  27. brainstate/optim/_lr_scheduler.py +1 -1
  28. brainstate/optim/_optax_optimizer.py +19 -0
  29. {brainstate-0.1.0.dist-info → brainstate-0.1.0.post20241125.dist-info}/METADATA +2 -2
  30. {brainstate-0.1.0.dist-info → brainstate-0.1.0.post20241125.dist-info}/RECORD +34 -24
  31. {brainstate-0.1.0.dist-info → brainstate-0.1.0.post20241125.dist-info}/top_level.txt +1 -0
  32. brainstate/event/_fixed_probability.py +0 -271
  33. brainstate/event/_linear.py +0 -219
  34. /brainstate/event/{_csr_test.py → _csr_mv_test.py} +0 -0
  35. {brainstate-0.1.0.dist-info → brainstate-0.1.0.post20241125.dist-info}/LICENSE +0 -0
  36. {brainstate-0.1.0.dist-info → brainstate-0.1.0.post20241125.dist-info}/WHEEL +0 -0
@@ -1,5 +1,7 @@
1
+ benchmark/COBA_2005.py,sha256=Q8PsZ0lxu14jsF3bCtlZW35iQB8S2_oFEUYQzK2hPiA,5561
2
+ benchmark/CUBA_2005.py,sha256=_W94yOMh2ueqblU4ItEPeTLwHF0_lbEWlVNEBy0Tix0,6222
1
3
  brainstate/__init__.py,sha256=r7C3eLTg8LEusoH6PGgBFFt4ZgbketYLoLA0lQhUCsE,2098
2
- brainstate/_state.py,sha256=L-wDWR04ON1VAqBDv0cHdGM-DVK9lHdH1p2LCKMWoJA,28802
4
+ brainstate/_state.py,sha256=0h3B32130Tvv41eeolYTbEGT9FZ9WUZ9yYoGyVnte_c,28808
3
5
  brainstate/_state_test.py,sha256=1boTp1w8DiCFLsPwNtlLrlIqGRpkasAmLid5bv2fgP4,2223
4
6
  brainstate/_utils.py,sha256=uJ6WWKq3yb05ZdktCQGLWOXsOJveL1H9pR7eev70Jes,1693
5
7
  brainstate/environ.py,sha256=G6r_rqfbofRbjFFalRu_DHaL7ruFTeLRXBQDXM6P-tQ,17477
@@ -10,8 +12,8 @@ brainstate/surrogate.py,sha256=YaY6RJ6kzpuPXWFjaWsxWt2MzJfdm5v_jeOR8V_jPoU,48369
10
12
  brainstate/transform.py,sha256=uryXL2vAD-oApUNvL4n6uzLPRwXXARCJ6lRwco_Gh_Y,772
11
13
  brainstate/typing.py,sha256=XaqpkfkMxBYkze5DoCQ48GSxuyO0C8rGOKeoN9qeZNo,10425
12
14
  brainstate/augment/__init__.py,sha256=BtXIBel7GbttmfBX6grxOxl0IiOJxLEa7qCGAXumamE,1286
13
- brainstate/augment/_autograd.py,sha256=4pbiAYCwnBuD5xpRo50HU5DvpZmnaB1Z2L-1TVY4OAE,25357
14
- brainstate/augment/_autograd_test.py,sha256=WIgUCENzW7pdtsR0PCHjOJnAz4eI1Ha9wlQsS6V8Yjc,41283
15
+ brainstate/augment/_autograd.py,sha256=o9ivoEY7BmtdM1XmzdMmeRXpj6Tvn5xNB8LSGp2HKC8,25238
16
+ brainstate/augment/_autograd_test.py,sha256=S2eEgrwTzdSi3u2nKE3u37WSThosLwx1WCP9ptJAGKo,44060
15
17
  brainstate/augment/_eval_shape.py,sha256=dGlRVHOAZ9LSRZsFi1erxgEWHrnhBO3Kq3WW11-Hvng,3819
16
18
  brainstate/augment/_eval_shape_test.py,sha256=1nnxbU7hPRbZPQWNWbQ518pw-H7FGDKKnQpZGBY9uRI,1390
17
19
  brainstate/augment/_mapping.py,sha256=cpxzVGCEYnP5jPqrowYoPXciw_-QR2F3wggrRj1OCPc,21850
@@ -35,14 +37,20 @@ brainstate/compile/_make_jaxpr_test.py,sha256=qJUtkyj50JQ6f4UJbOLhvRdkbNn3NSKibF
35
37
  brainstate/compile/_progress_bar.py,sha256=LML4DjrLSIeGYJWLjqy6BnHSz03fu1gnjf-7kljP384,3824
36
38
  brainstate/compile/_unvmap.py,sha256=ewbLLNXiI_dBsEBaVzSS0BEXNol22sd9gMzk606lSkM,4139
37
39
  brainstate/compile/_util.py,sha256=aCvkTV--g4NsqcodTdBAISt4EwgezCbKzNUV58n-Q_Y,6304
38
- brainstate/event/__init__.py,sha256=83a1IWZ_Oma1XhifEyt_l7jnEmRB-Q4Vd-t0PIolGkQ,1058
39
- brainstate/event/_csr.py,sha256=rEBPLnGCmf8kslB0UFsOnMjeg8s1oOt-a6mJrii3b_M,10467
40
- brainstate/event/_csr_test.py,sha256=qx3TZKnvIC5grSubDFRp7vubu2RJGeCDZFPmof-bRiA,3874
41
- brainstate/event/_fixed_probability.py,sha256=kZYghzEopb2lQPO3RnCxtta80hfVM-9sdSJ6gzKlcjM,9656
42
- brainstate/event/_fixed_probability_test.py,sha256=UFeb5ey2mLFaIPvgwrEE2HUjAIRwb3MuHsnnlfqNs9w,4119
43
- brainstate/event/_linear.py,sha256=kgSA9xzIMzUAl_ztKHuDG5gUZsBCr1AgvuHkAdkEyH4,7006
44
- brainstate/event/_linear_test.py,sha256=AKttUDxlFvdeN8j3EKSGDWaB629ANYRcDg2G_n5Tbkc,3577
40
+ brainstate/event/__init__.py,sha256=wOBkq7kDg90M8Y9FuoXRlSEuu1ZzbIhCJ1dHeLqN6_Q,1194
41
+ brainstate/event/_csr_benchmark.py,sha256=xrj2DSWzw0pUHAE1jRBeSRhMW7ogXvDHEdeaZGioNE4,702
42
+ brainstate/event/_csr_mv.py,sha256=4PVSK6QuuDK6pfA5SXvUU4Cxwkd-alGJ2u7VI0irwaQ,10718
43
+ brainstate/event/_csr_mv_benchmark.py,sha256=xrj2DSWzw0pUHAE1jRBeSRhMW7ogXvDHEdeaZGioNE4,702
44
+ brainstate/event/_csr_mv_test.py,sha256=qx3TZKnvIC5grSubDFRp7vubu2RJGeCDZFPmof-bRiA,3874
45
+ brainstate/event/_fixedprob_mv.py,sha256=uFJBlE3hh_QHl3TvNy061xMuLCykuAJGOP3kc-YfH2w,25187
46
+ brainstate/event/_fixedprob_mv_benchmark.py,sha256=_F_8fH5MNMJZHeSqnq9DYMI9OgYr6JIxBKjbsgeWRv4,4720
47
+ brainstate/event/_fixedprob_mv_test.py,sha256=jijrtJ5fnwcLmA7Tjd3vDlzwfbftmLoVTN4-MPuogVc,4201
48
+ brainstate/event/_linear_mv.py,sha256=n1NxBUQHWPHBp_xjwpLvtKW8RMvjR5cRjdV_fh_2gGo,11788
49
+ brainstate/event/_linear_mv_benckmark.py,sha256=hu0WqYMIa3jMoH7Fq9dgxcBjjXGFhghPx9vztyCo1KY,2411
50
+ brainstate/event/_linear_mv_test.py,sha256=V9w41ZP2vu95CyCdCkm-j9Eftqs2kqmeBY809N1-syY,3736
45
51
  brainstate/event/_misc.py,sha256=8IpPooXjF2m0-tuo3pGHqThq2yLSNmYziy_zdurZ3NI,1040
52
+ brainstate/event/_xla_custom_op.py,sha256=QB4jz_fUEPF-efJCVKAxwx8U79AqdcKoEg2QrGwot8I,10864
53
+ brainstate/event/_xla_custom_op_test.py,sha256=rnkGMleXzLfJj4y5QqwfBvCCLTAHe_uabwBDniY-URM,1745
46
54
  brainstate/functional/__init__.py,sha256=j6-3Er4fgqWpvntzYCZVB3e5hoz-Z3aqvapITCuDri0,1107
47
55
  brainstate/functional/_activations.py,sha256=tsCGoV_35IyZ0mM6_4fV9v0-Vj3V9Qm55U8wSpZ_E4o,22180
48
56
  brainstate/functional/_activations_test.py,sha256=T___RlSrIfXwlkw8dg5A9EZMTZGDzv3a2evUwq_nYFg,13034
@@ -66,7 +74,7 @@ brainstate/init/_regular_inits.py,sha256=DmVMajugfyYFNUMzgFdDKMvbBu9hMWxkfDd-50u
66
74
  brainstate/init/_regular_inits_test.py,sha256=tJl4aOkclllJIfKzJTbc0cfYCw2SoBsx8_G123RnqbU,1842
67
75
  brainstate/nn/__init__.py,sha256=rxURT8J1XfBn3Vh3Dx_WzVADWn9zVriIty5KZEG-x6o,1622
68
76
  brainstate/nn/_collective_ops.py,sha256=BzKKUxeBPtZojHqvOzgPkz9EFaVJUwQWmqR18mF2l38,6233
69
- brainstate/nn/_exp_euler.py,sha256=3Q50V9TlDU-peSu0Exr9cjuO8nGZxrF-qljH5qxT6yM,3520
77
+ brainstate/nn/_exp_euler.py,sha256=yjkfSllFxGWKEAlHo5AzBizzkFj6FEVDKmFV6E2g214,3521
70
78
  brainstate/nn/_exp_euler_test.py,sha256=clwRD8QR71k1jn6NrACMDEUcFMh0J9RTosoPnlYWUkw,1242
71
79
  brainstate/nn/_module.py,sha256=HDLPvLfB7jat2VT3gBu0MxA7vfzK7xgowemitHX8Cgo,10835
72
80
  brainstate/nn/_module_test.py,sha256=V4ZhiY_zYPvArkB2eeOTtZcgQrtlRyXKMbS1AJH4vC8,8893
@@ -74,18 +82,18 @@ brainstate/nn/metrics.py,sha256=iupHjSRTHYY-HmEPBC4tXWrZfF4zh1ek2NwSAA0gnwE,1473
74
82
  brainstate/nn/_dyn_impl/__init__.py,sha256=Oazar7h89dp1WA2Vx4Tj7gCBhxJKH4LAUEABkBEG7vU,1462
75
83
  brainstate/nn/_dyn_impl/_dynamics_neuron.py,sha256=cTbIn41EPYG0h3ICzKBXxpgB6wwA2K8k5FAcf3Pa5N8,10927
76
84
  brainstate/nn/_dyn_impl/_dynamics_neuron_test.py,sha256=Tfzrzu7udGrLJGnqItiLWe5WT0dgduvYOgzGCnaPJQg,6317
77
- brainstate/nn/_dyn_impl/_dynamics_synapse.py,sha256=UYRdbXsuLRS_PWuF1hGXUIZEC-g0uuVzp3VSyF1rQxk,11948
85
+ brainstate/nn/_dyn_impl/_dynamics_synapse.py,sha256=BP-ko0FyjWZopuUhAy3Ot3wWRQlGcpumWJpKrQakqok,11869
78
86
  brainstate/nn/_dyn_impl/_dynamics_synapse_test.py,sha256=t5i-HV0ii9sUNzWTEv04o26QVtQ-mCdMJcFq2MD755A,4981
79
87
  brainstate/nn/_dyn_impl/_inputs.py,sha256=6eZKnkmrM0Gog2fpSKjSnwnQvhbFYhG4q9Vuo-GH2LI,5050
80
88
  brainstate/nn/_dyn_impl/_projection_alignpost.py,sha256=PNC1Tzx_SF2DHAHeJCufXzO_Q4qLoBpWABI45B3GRuc,876
81
- brainstate/nn/_dyn_impl/_rate_rnns.py,sha256=TJOD3av5gnDGZlqcrxRo4chuCixrDhUsYbmPEhU-608,14773
89
+ brainstate/nn/_dyn_impl/_rate_rnns.py,sha256=dz_yT_6hJVhKulcjIARbGtmMzZqISws96CtBc6o5GOo,14768
82
90
  brainstate/nn/_dyn_impl/_rate_rnns_test.py,sha256=gNgtr-a4ZiU1XF9wFG1HiJ9fLosfWchVR9Zn1x39xt4,2452
83
91
  brainstate/nn/_dyn_impl/_readout.py,sha256=iYk2lKkB42OClLUlXQVr8SIqL4NzwZzVE3rlEAExGvw,4370
84
92
  brainstate/nn/_dyn_impl/_readout_test.py,sha256=R9JJPRvy3mAHSv8n1Hzjk2kBSDjBzJNbS83ystll86s,2109
85
93
  brainstate/nn/_dynamics/__init__.py,sha256=j1HSWu01wf5-KjSaNhBC9utVGDALOhUsFPrLPcPPDsM,1208
86
94
  brainstate/nn/_dynamics/_dynamics_base.py,sha256=vOOi7lQQmfVUVubm_1G2Xj3kAd4S9FFtbUDKxDx97Kg,21637
87
95
  brainstate/nn/_dynamics/_dynamics_base_test.py,sha256=gXMwENqqSvyZbMpLP0QtYndJ_h39dF5gIeiiSbMAjTk,2721
88
- brainstate/nn/_dynamics/_projection_base.py,sha256=FXeD8UTdoEtm1_CG8ZOtzfnD8Cx6BijsxwqhG4ewdJg,12596
96
+ brainstate/nn/_dynamics/_projection_base.py,sha256=nSdaA1UgCormDHN7LFebhHl7CjWUuJ4b3zeJr0vq5Ms,12597
89
97
  brainstate/nn/_dynamics/_state_delay.py,sha256=nZYGmVKmQvAQu-W4YOUFH1gnr-ZS3rg_GNkRhI8rQ-I,16761
90
98
  brainstate/nn/_dynamics/_synouts.py,sha256=9TGAc-nVa50th7KKn4oKLbro-4W4rwxYvp-eu7ksAIE,4491
91
99
  brainstate/nn/_dynamics/_synouts_test.py,sha256=V_jDswRN4VvEXD-2yJO3VA1TALgX0HK6oPBQiUntOWc,2266
@@ -94,19 +102,21 @@ brainstate/nn/_elementwise/_dropout.py,sha256=jxqce-lcH5P0AotHZKhdp_Ho1n7qUuvSe6
94
102
  brainstate/nn/_elementwise/_dropout_test.py,sha256=Qn7xqZOyZMPCGF6tFjTiPId0yELOXjSsW5-hgihP3fE,4383
95
103
  brainstate/nn/_elementwise/_elementwise.py,sha256=om-KpwDTk5yFG5KBYXXHquRLV7s28_FJjk-omvyMyvQ,33342
96
104
  brainstate/nn/_elementwise/_elementwise_test.py,sha256=SZI9jB39sZ5SO1dpWGW-PhodthwN0GU9FY1nqf2fWcs,5341
97
- brainstate/nn/_interaction/__init__.py,sha256=BmiJxuORxiGoziApKlcMTQ-Q5TyV57p5yGKFxNanDCQ,1196
98
- brainstate/nn/_interaction/_connections.py,sha256=46OYPtLvJZHhJcA6ZgxYKdqnjR22xftac4PQSnoBj0w,26382
99
- brainstate/nn/_interaction/_connections_test.py,sha256=53GHjv_YieAzwht4F9kpYZ83a5bOaLa4WRHa7e2iZf4,9076
105
+ brainstate/nn/_interaction/__init__.py,sha256=TTY_SeNrdx4VnUSw6vdyl02OHdS9Qs15cWBp6kjsyNQ,1289
106
+ brainstate/nn/_interaction/_conv.py,sha256=LgWO4TeKRru07UEUga3YX6xog6WHtOvKdKtgxGnHUvw,18512
107
+ brainstate/nn/_interaction/_conv_test.py,sha256=fHXRFYnDghFiKre63RqMwIE_gbPKdK34UPhKOz-J3qU,8695
100
108
  brainstate/nn/_interaction/_embedding.py,sha256=iK0I1ExKWFa_QzV9UDGj32Ljsmdr1g_LlAtMcusebxU,2187
109
+ brainstate/nn/_interaction/_linear.py,sha256=bjiWGJCe81ugQQOykwjWlLW5uhe0CHWwkPA20a4n5YQ,21340
110
+ brainstate/nn/_interaction/_linear_test.py,sha256=KlvFZA0rpyaspf4LT4K7u-RR5jCEB_q1WReqAw9sFcU,1274
101
111
  brainstate/nn/_interaction/_normalizations.py,sha256=j0sI8prfK65daQBizPFBHbs57zfGXzlnN2DxKqoN8Wk,14842
102
112
  brainstate/nn/_interaction/_normalizations_test.py,sha256=2p1Jf8nA999VYGWbvOZfKYlKk6UmL0vaEB76xkXxkXw,2438
103
113
  brainstate/nn/_interaction/_poolings.py,sha256=LpwuyeNBVCaVFW7zWc7E-vvlYqx54h46Br5XT6zd_94,47020
104
114
  brainstate/nn/_interaction/_poolings_test.py,sha256=wmd5PngZ3E9tNyF3s0xk-DoDR5yFqpTi9A6nbNoIqn4,7429
105
115
  brainstate/optim/__init__.py,sha256=7Ao0LCtDNAoxSRSXiLLKnd1_4mR2GSExizpN38il-Fo,1195
106
116
  brainstate/optim/_base.py,sha256=NbP3fzVslfnmJAOWAPD7o9TDWeRdw4CRdNfnfkMTfkU,1873
107
- brainstate/optim/_lr_scheduler.py,sha256=4E6u86C70QTMcIipZ22YoYLQM5R3d90hCgD9FTNJZb4,15324
117
+ brainstate/optim/_lr_scheduler.py,sha256=Tw-aH5knnMh9eAi2LdxkZ6cAk5DbEn0tW0C738xqjFA,15325
108
118
  brainstate/optim/_lr_scheduler_test.py,sha256=_amijy9WzuvVWRC4GiuzyaC_278QG97EYZ1WtsE2IyA,1778
109
- brainstate/optim/_optax_optimizer.py,sha256=yctoMbFdtgaih2X6JqhbxHF575EHS-lMZ8X_Y6_USVY,4825
119
+ brainstate/optim/_optax_optimizer.py,sha256=SuXV_xUBfhOw1_C2J5TIpy3dXDtI9VJFaSMLy8hLcXE,5312
110
120
  brainstate/optim/_optax_optimizer_test.py,sha256=DAomE8Eu3dn4gh1S3EZ_u4pW4rhcl16vWPbnDcN3Rs4,1762
111
121
  brainstate/optim/_sgd_optimizer.py,sha256=NVKYhGcw2D1ksNWUIXZcj-74LUaan8XL3EERk-EHMRI,46008
112
122
  brainstate/random/__init__.py,sha256=c5q-RC3grRIjx-HBb2IhKZpi_xzbFmUUxzRAzqfREic,1045
@@ -128,8 +138,8 @@ brainstate/util/_scaling.py,sha256=pc_eM_SZVwkY65I4tJh1ODiHNCoEhsfFXl2zBK0PLAg,7
128
138
  brainstate/util/_struct.py,sha256=0exv0oOiSt1hmx20Y4J2-pCGtCTx13WcAlEYSBkyung,17640
129
139
  brainstate/util/_tracers.py,sha256=-sX76GJRThdSpDJBejAIzDdBbVhmH6kb-1WoDJVI7V0,2556
130
140
  brainstate/util/_visualization.py,sha256=n4ZVz10z7VBqA0cKO6vyHwEMprWJgPeEqtITzDMai2Y,1519
131
- brainstate-0.1.0.dist-info/LICENSE,sha256=VZe9u1jgUL2eCY6ZPOYgdb8KCblCHt8ECdbtJid6e1s,11550
132
- brainstate-0.1.0.dist-info/METADATA,sha256=WmOW4lSuLuiDjOBC29ngnWvQ3p7vV2VA0BCluQpt2dE,3388
133
- brainstate-0.1.0.dist-info/WHEEL,sha256=bb2Ot9scclHKMOLDEHY6B2sicWOgugjFKaJsT7vwMQo,110
134
- brainstate-0.1.0.dist-info/top_level.txt,sha256=eQbGgKn0ptx7FDWuua0V0wr4K1VHi2iOUCYo3fUQBRA,11
135
- brainstate-0.1.0.dist-info/RECORD,,
141
+ brainstate-0.1.0.post20241125.dist-info/LICENSE,sha256=VZe9u1jgUL2eCY6ZPOYgdb8KCblCHt8ECdbtJid6e1s,11550
142
+ brainstate-0.1.0.post20241125.dist-info/METADATA,sha256=PV2RbaJa62UxM18IcJFfmHTSt2ClcZi6LnIS1Tbky-4,3401
143
+ brainstate-0.1.0.post20241125.dist-info/WHEEL,sha256=bb2Ot9scclHKMOLDEHY6B2sicWOgugjFKaJsT7vwMQo,110
144
+ brainstate-0.1.0.post20241125.dist-info/top_level.txt,sha256=MVkn5SZk0qis32_AGxU2Tg1xADfj2IgCNS25CQD7_ng,21
145
+ brainstate-0.1.0.post20241125.dist-info/RECORD,,
@@ -1,271 +0,0 @@
1
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
- from __future__ import annotations
16
-
17
- from typing import Union, Callable, Optional
18
-
19
- import brainunit as u
20
- import jax
21
- import jax.numpy as jnp
22
- import numpy as np
23
-
24
- from brainstate._state import ParamState
25
- from brainstate._utils import set_module_as
26
- from brainstate.compile import for_loop
27
- from brainstate.init import param
28
- from brainstate.nn._module import Module
29
- from brainstate.random import RandomState
30
- from brainstate.typing import ArrayLike
31
- from ._misc import FloatScalar, IntScalar
32
-
33
- __all__ = [
34
- 'FixedProb',
35
- ]
36
-
37
-
38
- class FixedProb(Module):
39
- """
40
- The FixedProb module implements a fixed probability connection with CSR sparse data structure.
41
-
42
- Parameters
43
- ----------
44
- n_pre : int
45
- Number of pre-synaptic neurons.
46
- n_post : int
47
- Number of post-synaptic neurons.
48
- prob : float
49
- Probability of connection.
50
- weight : float or callable or jax.Array or brainunit.Quantity
51
- Maximum synaptic conductance.
52
- allow_multi_conn : bool, optional
53
- Whether multiple connections are allowed from a single pre-synaptic neuron.
54
- Default is True, meaning that a value of ``a`` can be selected multiple times.
55
- prob : float
56
- Probability of connection.
57
- name : str, optional
58
- Name of the module.
59
- """
60
-
61
- __module__ = 'brainstate.event'
62
-
63
- def __init__(
64
- self,
65
- n_pre: IntScalar,
66
- n_post: IntScalar,
67
- prob: FloatScalar,
68
- weight: Union[Callable, ArrayLike],
69
- allow_multi_conn: bool = True,
70
- seed: Optional[int] = None,
71
- name: Optional[str] = None,
72
- grad_mode: str = 'vjp'
73
- ):
74
- super().__init__(name=name)
75
- self.n_pre = n_pre
76
- self.n_post = n_post
77
- self.in_size = n_pre
78
- self.out_size = n_post
79
-
80
- self.n_conn = int(n_post * prob)
81
- if self.n_conn < 1:
82
- raise ValueError(
83
- f"The number of connections must be at least 1. Got: int({n_post} * {prob}) = {self.n_conn}")
84
-
85
- assert grad_mode in ['vjp', 'jvp'], f"Unsupported grad_mode: {grad_mode}"
86
- self.grad_mode = grad_mode
87
-
88
- # indices of post connected neurons
89
- if allow_multi_conn:
90
- self.indices = np.random.RandomState(seed).randint(0, n_post, size=(self.n_pre, self.n_conn))
91
- else:
92
- rng = RandomState(seed)
93
- self.indices = for_loop(lambda i: rng.choice(n_post, size=(self.n_conn,), replace=False), np.arange(n_pre))
94
- self.indices = u.math.asarray(self.indices)
95
-
96
- # maximum synaptic conductance
97
- weight = param(weight, (self.n_pre, self.n_conn), allow_none=False)
98
- self.weight = ParamState(weight)
99
-
100
- def update(self, spk: jax.Array) -> Union[jax.Array, u.Quantity]:
101
- device_kind = jax.devices()[0].platform # spk.device.device_kind
102
- if device_kind == 'cpu':
103
- return cpu_fixed_prob(self.indices,
104
- u.math.asarray(self.weight.value),
105
- u.math.asarray(spk),
106
- n_post=self.n_post,
107
- grad_mode=self.grad_mode)
108
- elif device_kind in ['gpu', 'tpu']:
109
- raise NotImplementedError()
110
- else:
111
- raise ValueError(f"Unsupported device: {device_kind}")
112
-
113
-
114
- @set_module_as('brainstate.event')
115
- def cpu_fixed_prob(
116
- indices: jax.Array,
117
- weight: Union[u.Quantity, jax.Array],
118
- spk: jax.Array,
119
- *,
120
- n_post: int,
121
- grad_mode: str = 'vjp'
122
- ) -> Union[u.Quantity, jax.Array]:
123
- """
124
- The FixedProb module implements a fixed probability connection with CSR sparse data structure.
125
-
126
- Parameters
127
- ----------
128
- n_post : int
129
- Number of post-synaptic neurons.
130
- weight : brainunit.Quantity or jax.Array
131
- Maximum synaptic conductance.
132
- spk : jax.Array
133
- Spike events.
134
- indices : jax.Array
135
- Indices of post connected neurons.
136
- grad_mode : str, optional
137
- Gradient mode. Default is 'vjp'. Can be 'vjp' or 'jvp'.
138
-
139
- Returns
140
- -------
141
- post_data : brainunit.Quantity or jax.Array
142
- Post synaptic data.
143
- """
144
- unit = u.get_unit(weight)
145
- weight = u.get_mantissa(weight)
146
- indices = jnp.asarray(indices)
147
- spk = jnp.asarray(spk)
148
-
149
- def mv(spk_vector):
150
- assert spk_vector.ndim == 1, f"spk must be 1D. Got: {spk.ndim}"
151
- if grad_mode == 'vjp':
152
- post_data = _cpu_event_fixed_prob_mv_vjp(indices, weight, spk_vector, n_post)
153
- elif grad_mode == 'jvp':
154
- post_data = _cpu_event_fixed_prob_mv_jvp(indices, weight, spk_vector, n_post)
155
- else:
156
- raise ValueError(f"Unsupported grad_mode: {grad_mode}")
157
- return post_data
158
-
159
- assert spk.ndim >= 1, f"spk must be at least 1D. Got: {spk.ndim}"
160
- assert weight.ndim in [2, 0], f"weight must be 2D or 0D. Got: {weight.ndim}"
161
- assert indices.ndim == 2, f"indices must be 2D. Got: {indices.ndim}"
162
-
163
- if spk.ndim == 1:
164
- post_data = mv(spk)
165
- else:
166
- shape = spk.shape[:-1]
167
- post_data = jax.vmap(mv)(u.math.reshape(spk, (-1, spk.shape[-1])))
168
- post_data = u.math.reshape(post_data, shape + post_data.shape[-1:])
169
- return u.maybe_decimal(u.Quantity(post_data, unit=unit))
170
-
171
-
172
- # -------------------
173
- # CPU Implementation
174
- # -------------------
175
-
176
-
177
- def _cpu_event_fixed_prob_mv(indices, g_max, spk, n_post: int) -> jax.Array:
178
- def scan_fn(post, i):
179
- w = g_max if jnp.size(g_max) == 1 else g_max[i]
180
- ids = indices[i]
181
- sp = spk[i]
182
- if spk.dtype == jnp.bool_:
183
- post = jax.lax.cond(sp, lambda: post.at[ids].add(w), lambda: post)
184
- else:
185
- post = jax.lax.cond(sp == 0., lambda: post, lambda: post.at[ids].add(w * sp))
186
- return post, None
187
-
188
- return jax.lax.scan(scan_fn, jnp.zeros((n_post,), dtype=g_max.dtype), np.arange(len(spk)))[0]
189
-
190
-
191
- # --------------
192
- # VJP
193
- # --------------
194
-
195
- def _cpu_event_fixed_prob_mv_fwd(indices, g_max, spk, n_post):
196
- return _cpu_event_fixed_prob_mv(indices, g_max, spk, n_post=n_post), (g_max, spk)
197
-
198
-
199
- def _cpu_event_fixed_prob_mv_bwd(indices, n_post, res, ct):
200
- weight, spk = res
201
-
202
- # ∂L/∂spk = ∂L/∂y * ∂y/∂spk
203
- homo = jnp.size(weight) == 1
204
- if homo: # homogeneous weight
205
- ct_spk = jax.vmap(lambda idx: jnp.sum(ct[idx] * weight))(indices)
206
- else: # heterogeneous weight
207
- ct_spk = jax.vmap(lambda idx, w: jnp.inner(ct[idx], w))(indices, weight)
208
-
209
- # ∂L/∂w = ∂L/∂y * ∂y/∂w
210
- if homo: # scalar
211
- ct_gmax = _cpu_event_fixed_prob_mv(indices, jnp.asarray(1.), spk, n_post=n_post)
212
- ct_gmax = jnp.inner(ct, ct_gmax)
213
- else:
214
- def scan_fn(d_gmax, i):
215
- if spk.dtype == jnp.bool_:
216
- d_gmax = jax.lax.cond(spk[i], lambda: d_gmax.at[i].add(ct[indices[i]]), lambda: d_gmax)
217
- else:
218
- d_gmax = jax.lax.cond(spk[i] == 0., lambda: d_gmax, lambda: d_gmax.at[i].add(ct[indices[i]] * spk[i]))
219
- return d_gmax, None
220
-
221
- ct_gmax = jax.lax.scan(scan_fn, jnp.zeros_like(weight), np.arange(len(spk)))[0]
222
- return ct_gmax, ct_spk
223
-
224
-
225
- _cpu_event_fixed_prob_mv_vjp = jax.custom_vjp(_cpu_event_fixed_prob_mv, nondiff_argnums=(0, 3))
226
- _cpu_event_fixed_prob_mv_vjp.defvjp(_cpu_event_fixed_prob_mv_fwd, _cpu_event_fixed_prob_mv_bwd)
227
-
228
-
229
- # --------------
230
- # JVP
231
- # --------------
232
-
233
-
234
- def _cpu_event_fixed_prob_mv_jvp_rule(indices, n_post, primals, tangents):
235
- # forward pass
236
- weight, spk = primals
237
- y = _cpu_event_fixed_prob_mv(indices, weight, spk, n_post=n_post)
238
-
239
- # forward gradients
240
- gmax_dot, spk_dot = tangents
241
-
242
- # ∂y/∂gmax
243
- dgmax = _cpu_event_fixed_prob_mv(indices, gmax_dot, spk, n_post=n_post)
244
-
245
- def scan_fn(post, i):
246
- ids = indices[i]
247
- w = weight if jnp.size(weight) == 1 else weight[i]
248
- post = post.at[ids].add(w * spk_dot[i])
249
- return post, None
250
-
251
- # ∂y/∂gspk
252
- dspk = jax.lax.scan(scan_fn, jnp.zeros((n_post,), dtype=weight.dtype), np.arange(len(spk)))[0]
253
- return y, dgmax + dspk
254
-
255
-
256
- _cpu_event_fixed_prob_mv_jvp = jax.custom_jvp(_cpu_event_fixed_prob_mv, nondiff_argnums=(0, 3))
257
- _cpu_event_fixed_prob_mv_jvp.defjvp(_cpu_event_fixed_prob_mv_jvp_rule)
258
-
259
-
260
- def _gpu_event_fixed_prob_mv(indices, g_max, spk, n_post: int) -> jax.Array:
261
- def scan_fn(post, i):
262
- w = g_max if jnp.size(g_max) == 1 else g_max[i]
263
- ids = indices[i]
264
- sp = spk[i]
265
- if spk.dtype == jnp.bool_:
266
- post = jax.lax.cond(sp, lambda: post.at[ids].add(w), lambda: post)
267
- else:
268
- post = jax.lax.cond(sp == 0., lambda: post, lambda: post.at[ids].add(w * sp))
269
- return post, None
270
-
271
- return jax.lax.scan(scan_fn, jnp.zeros((n_post,), dtype=g_max.dtype), np.arange(len(spk)))[0]
@@ -1,219 +0,0 @@
1
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
- from __future__ import annotations
16
-
17
- from typing import Union, Callable, Optional
18
-
19
- import brainunit as u
20
- import jax
21
- import jax.numpy as jnp
22
- import numpy as np
23
-
24
- from brainstate._state import ParamState, State
25
- from brainstate._utils import set_module_as
26
- from brainstate.init import param
27
- from brainstate.nn._module import Module
28
- from brainstate.typing import ArrayLike
29
- from ._misc import IntScalar
30
-
31
- __all__ = [
32
- 'Linear',
33
- ]
34
-
35
-
36
- class Linear(Module):
37
- """
38
- The FixedProb module implements a fixed probability connection with CSR sparse data structure.
39
-
40
- Parameters
41
- ----------
42
- n_pre : int
43
- Number of pre-synaptic neurons.
44
- n_post : int
45
- Number of post-synaptic neurons.
46
- weight : float or callable or jax.Array or brainunit.Quantity
47
- Maximum synaptic conductance.
48
- name : str, optional
49
- Name of the module.
50
- """
51
-
52
- __module__ = 'brainstate.event'
53
-
54
- def __init__(
55
- self,
56
- n_pre: IntScalar,
57
- n_post: IntScalar,
58
- weight: Union[Callable, ArrayLike],
59
- name: Optional[str] = None,
60
- grad_mode: str = 'vjp'
61
- ):
62
- super().__init__(name=name)
63
- self.n_pre = n_pre
64
- self.n_post = n_post
65
- self.in_size = n_pre
66
- self.out_size = n_post
67
-
68
- assert grad_mode in ['vjp', 'jvp'], f"Unsupported grad_mode: {grad_mode}"
69
- self.grad_mode = grad_mode
70
-
71
- # maximum synaptic conductance
72
- weight = param(weight, (self.n_pre, self.n_post), allow_none=False)
73
- self.weight = ParamState(weight)
74
-
75
- def update(self, spk: jax.Array) -> Union[jax.Array, u.Quantity]:
76
- weight = self.weight.value if isinstance(self.weight, State) else self.weight
77
- if u.math.size(weight) == 1:
78
- return u.math.ones(self.n_post) * (u.math.sum(spk) * weight)
79
-
80
- device_kind = jax.devices()[0].platform # spk.device.device_kind
81
- if device_kind == 'cpu':
82
- return cpu_event_linear(u.math.asarray(weight),
83
- u.math.asarray(spk),
84
- n_post=self.n_post,
85
- grad_mode=self.grad_mode)
86
- elif device_kind in ['gpu', 'tpu']:
87
- raise NotImplementedError()
88
- else:
89
- raise ValueError(f"Unsupported device: {device_kind}")
90
-
91
-
92
- @set_module_as('brainstate.event')
93
- def cpu_event_linear(
94
- g_max: Union[u.Quantity, jax.Array],
95
- spk: jax.Array,
96
- *,
97
- n_post: int = None,
98
- grad_mode: str = 'vjp'
99
- ) -> Union[u.Quantity, jax.Array]:
100
- """
101
- The FixedProb module implements a fixed probability connection with CSR sparse data structure.
102
-
103
- Parameters
104
- ----------
105
- n_post : int
106
- Number of post-synaptic neurons.
107
- g_max : brainunit.Quantity or jax.Array
108
- Maximum synaptic conductance.
109
- spk : jax.Array
110
- Spike events.
111
- grad_mode : str, optional
112
- Gradient mode. Default is 'vjp'. Can be 'vjp' or 'jvp'.
113
-
114
- Returns
115
- -------
116
- post_data : brainunit.Quantity or jax.Array
117
- Post synaptic data.
118
- """
119
- unit = u.get_unit(g_max)
120
- g_max = u.get_mantissa(g_max)
121
- spk = jnp.asarray(spk)
122
-
123
- def mv(spk_vector):
124
- assert spk_vector.ndim == 1, f"spk must be 1D. Got: {spk.ndim}"
125
- if jnp.size(g_max) == 1:
126
- assert isinstance(n_post, int), f"n_post must be an integer when weight is homogenous. Got: {n_post}"
127
- # return jnp.full((n_post,), fill_value=jnp.sum(spk_vector) * weight)
128
- return jnp.ones((n_post,), dtype=g_max.dtype) * (jnp.sum(spk_vector) * g_max)
129
-
130
- if grad_mode == 'vjp':
131
- post = _cpu_event_linear_mv_vjp(g_max, spk_vector)
132
- elif grad_mode == 'jvp':
133
- post = _cpu_event_linear_mv_jvp(g_max, spk_vector)
134
- else:
135
- raise ValueError(f"Unsupported grad_mode: {grad_mode}")
136
- return post
137
-
138
- assert spk.ndim >= 1, f"spk must be at least 1D. Got: {spk.ndim}"
139
- assert g_max.ndim in [2, 0], f"weight must be 2D or 0D. Got: {g_max.ndim}"
140
-
141
- if spk.ndim == 1:
142
- post_data = mv(spk)
143
- else:
144
- shape = spk.shape[:-1]
145
- post_data = jax.vmap(mv)(u.math.reshape(spk, (-1, spk.shape[-1])))
146
- post_data = u.math.reshape(post_data, shape + post_data.shape[-1:])
147
- return u.maybe_decimal(u.Quantity(post_data, unit=unit))
148
-
149
-
150
- # --------------
151
- # Implementation
152
- # --------------
153
-
154
-
155
- def _cpu_event_linear_mv(g_max, spk) -> jax.Array:
156
- def scan_fn(post, i):
157
- sp = spk[i]
158
- if spk.dtype == jnp.bool_:
159
- post = jax.lax.cond(sp, lambda: post + g_max[i], lambda: post)
160
- else:
161
- post = jax.lax.cond(sp == 0., lambda: post, lambda: post + g_max[i] * sp)
162
- return post, None
163
-
164
- return jax.lax.scan(scan_fn, jnp.zeros(g_max.shape[1], dtype=g_max.dtype), np.arange(len(spk)))[0]
165
-
166
-
167
- # --------------
168
- # VJP
169
- # --------------
170
-
171
- def _cpu_event_linear_mv_fwd(g_max, spk):
172
- return _cpu_event_linear_mv(g_max, spk), (g_max, spk)
173
-
174
-
175
- def _cpu_event_linear_mv_bwd(res, ct):
176
- g_max, spk = res
177
-
178
- # ∂L/∂spk = ∂L/∂y * ∂y/∂spk
179
- ct_spk = jnp.matmul(g_max, ct)
180
-
181
- # ∂L/∂w = ∂L/∂y * ∂y/∂w
182
- def map_fn(sp):
183
- if spk.dtype == jnp.bool_:
184
- d_gmax = jax.lax.cond(sp, lambda: ct, lambda: jnp.zeros_like(ct))
185
- else:
186
- d_gmax = jax.lax.cond(sp == 0., lambda: jnp.zeros_like(ct), lambda: ct * sp)
187
- return d_gmax
188
-
189
- ct_gmax = jax.vmap(map_fn)(spk)
190
- return ct_gmax, ct_spk
191
-
192
-
193
- _cpu_event_linear_mv_vjp = jax.custom_vjp(_cpu_event_linear_mv)
194
- _cpu_event_linear_mv_vjp.defvjp(_cpu_event_linear_mv_fwd, _cpu_event_linear_mv_bwd)
195
-
196
-
197
- # --------------
198
- # JVP
199
- # --------------
200
-
201
-
202
- def _cpu_event_linear_mv_jvp_rule(primals, tangents):
203
- # forward pass
204
- g_max, spk = primals
205
- y = _cpu_event_linear_mv(g_max, spk)
206
-
207
- # forward gradients
208
- gmax_dot, spk_dot = tangents
209
-
210
- # ∂y/∂gmax
211
- dgmax = _cpu_event_linear_mv(gmax_dot, spk)
212
-
213
- # ∂y/∂gspk
214
- dspk = spk_dot @ g_max
215
- return y, dgmax + dspk
216
-
217
-
218
- _cpu_event_linear_mv_jvp = jax.custom_jvp(_cpu_event_linear_mv)
219
- _cpu_event_linear_mv_jvp.defjvp(_cpu_event_linear_mv_jvp_rule)
File without changes