brainstate 0.1.0.post20250104__py2.py3-none-any.whl → 0.1.0.post20250120__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/_state.py +77 -44
- brainstate/_state_test.py +0 -17
- brainstate/augment/_eval_shape.py +9 -10
- brainstate/augment/_eval_shape_test.py +1 -1
- brainstate/augment/_mapping.py +265 -277
- brainstate/augment/_mapping_test.py +147 -175
- brainstate/compile/_ad_checkpoint.py +6 -4
- brainstate/compile/_error_if_test.py +1 -0
- brainstate/compile/_jit.py +37 -28
- brainstate/compile/_loop_collect_return.py +8 -5
- brainstate/compile/_loop_no_collection.py +2 -0
- brainstate/compile/_make_jaxpr.py +7 -3
- brainstate/compile/_make_jaxpr_test.py +2 -1
- brainstate/compile/_progress_bar.py +68 -40
- brainstate/compile/_unvmap.py +6 -2
- brainstate/environ.py +28 -18
- brainstate/environ_test.py +4 -0
- brainstate/event/__init__.py +0 -2
- brainstate/event/_csr.py +266 -23
- brainstate/event/_csr_test.py +187 -0
- brainstate/event/_fixedprob_mv.py +4 -2
- brainstate/event/_fixedprob_mv_test.py +2 -1
- brainstate/event/_xla_custom_op.py +16 -5
- brainstate/graph/__init__.py +8 -12
- brainstate/graph/_graph_node.py +1 -23
- brainstate/graph/_graph_operation.py +1 -1
- brainstate/graph/_graph_operation_test.py +0 -159
- brainstate/nn/_dyn_impl/_inputs.py +124 -39
- brainstate/nn/_interaction/_conv.py +4 -2
- brainstate/nn/_interaction/_linear.py +84 -10
- brainstate/random/_rand_funs.py +9 -2
- brainstate/random/_rand_seed.py +12 -2
- brainstate/random/_rand_state.py +50 -179
- brainstate/surrogate.py +5 -1
- brainstate/util/__init__.py +0 -4
- brainstate/util/_caller.py +1 -1
- brainstate/util/_dict.py +4 -1
- brainstate/util/_filter.py +1 -1
- brainstate/util/_pretty_repr.py +1 -1
- brainstate/util/_struct.py +1 -1
- {brainstate-0.1.0.post20250104.dist-info → brainstate-0.1.0.post20250120.dist-info}/METADATA +2 -1
- {brainstate-0.1.0.post20250104.dist-info → brainstate-0.1.0.post20250120.dist-info}/RECORD +46 -52
- brainstate/event/_csr_mv_test.py +0 -118
- brainstate/graph/_graph_context.py +0 -443
- brainstate/graph/_graph_context_test.py +0 -65
- brainstate/graph/_graph_convert.py +0 -246
- brainstate/util/_tracers.py +0 -68
- brainstate/util/_visualization.py +0 -47
- /brainstate/event/{_csr_mv_benchmark.py → _csr_benchmark.py} +0 -0
- {brainstate-0.1.0.post20250104.dist-info → brainstate-0.1.0.post20250120.dist-info}/LICENSE +0 -0
- {brainstate-0.1.0.post20250104.dist-info → brainstate-0.1.0.post20250120.dist-info}/WHEEL +0 -0
- {brainstate-0.1.0.post20250104.dist-info → brainstate-0.1.0.post20250120.dist-info}/top_level.txt +0 -0
@@ -1,54 +1,53 @@
|
|
1
1
|
brainstate/__init__.py,sha256=A-QKdOvSalsCMxgk80Iz6_xMiUin6con6JaONHfciSY,1526
|
2
|
-
brainstate/_state.py,sha256=
|
3
|
-
brainstate/_state_test.py,sha256=
|
2
|
+
brainstate/_state.py,sha256=GZ46liHZSHbAHQEuELvOeoJ27P9xiZDz06G2AASjAjA,29142
|
3
|
+
brainstate/_state_test.py,sha256=rJUFRSXEqrrl4qANRewY9mnDlzSbtHwBIGeZ0ku-8Dg,1650
|
4
4
|
brainstate/_utils.py,sha256=uJ6WWKq3yb05ZdktCQGLWOXsOJveL1H9pR7eev70Jes,1693
|
5
|
-
brainstate/environ.py,sha256=
|
6
|
-
brainstate/environ_test.py,sha256=
|
5
|
+
brainstate/environ.py,sha256=PZnVFWPioUBuWmwCO8wwCKrHQfP3BR-5lYPRl5i5GDA,17698
|
6
|
+
brainstate/environ_test.py,sha256=QD6sPCKNtqemVCGwkdImjMazatrvvLr6YeAVcfUnVVY,2045
|
7
7
|
brainstate/mixin.py,sha256=g7uVUwZphZWsNs9pb48ozG2cDGaj0hs0g3lq8tDk-Sg,11310
|
8
8
|
brainstate/mixin_test.py,sha256=Oq_0fwC9vpXDN4t4dTBhWzLdFDNlcYsrcip14F1yECI,3079
|
9
|
-
brainstate/surrogate.py,sha256=
|
9
|
+
brainstate/surrogate.py,sha256=t4SzVwUVMAPtC-O1vFbuE9F4265wgAv7ud77ufIJsuk,48464
|
10
10
|
brainstate/transform.py,sha256=cxbymTlJ6uHvJWEEYXzFUkAySs_TbUTHakt0NQgWJ3s,808
|
11
11
|
brainstate/typing.py,sha256=Qh-LBzm6oG4rSXv4V5qB8SNYcoOR7bASoK_iQxnlafk,10467
|
12
12
|
brainstate/augment/__init__.py,sha256=BtXIBel7GbttmfBX6grxOxl0IiOJxLEa7qCGAXumamE,1286
|
13
13
|
brainstate/augment/_autograd.py,sha256=o9ivoEY7BmtdM1XmzdMmeRXpj6Tvn5xNB8LSGp2HKC8,25238
|
14
14
|
brainstate/augment/_autograd_test.py,sha256=S2eEgrwTzdSi3u2nKE3u37WSThosLwx1WCP9ptJAGKo,44060
|
15
|
-
brainstate/augment/_eval_shape.py,sha256=
|
16
|
-
brainstate/augment/_eval_shape_test.py,sha256=
|
17
|
-
brainstate/augment/_mapping.py,sha256=
|
18
|
-
brainstate/augment/_mapping_test.py,sha256=
|
15
|
+
brainstate/augment/_eval_shape.py,sha256=ObCgsZ704kLduB1dbjJZh5nVQYEkLR5ebK74V5NV42k,3892
|
16
|
+
brainstate/augment/_eval_shape_test.py,sha256=LFOJx7CWltmRLXdGY175UebLwtEMz2CzJ_gLqMZsJTw,1393
|
17
|
+
brainstate/augment/_mapping.py,sha256=nU6Y7fSnYXyQSILXU2QT-O73Fm3pnwOmgUoDaHqjve8,21544
|
18
|
+
brainstate/augment/_mapping_test.py,sha256=_KFhE3CXItwpbZ1gJfrDu3yUtX0YbfPUuHJG_G_BXEs,8963
|
19
19
|
brainstate/augment/_random.py,sha256=rkB4w4BkKsz9p8lTk31kVHvlVPJSvtGk8REn936KI_4,3071
|
20
20
|
brainstate/compile/__init__.py,sha256=qZZIYoyEl51IFkFu-Hb-bP3PAEHo94HlTDf57P2ze08,1858
|
21
|
-
brainstate/compile/_ad_checkpoint.py,sha256=
|
21
|
+
brainstate/compile/_ad_checkpoint.py,sha256=K6I4vnznDsqC9cUeCnez9UdV9r_toGA3zHezoHLA6mI,9377
|
22
22
|
brainstate/compile/_ad_checkpoint_test.py,sha256=R1I76nG4zIqb6g3M_VxWts7rUC1OHJCjtQhPkcbXodk,1746
|
23
23
|
brainstate/compile/_conditions.py,sha256=gApsHKGQrf1QBjoKXDVL7VsoeJ2zFtSc-hFz9nbYcF0,10113
|
24
24
|
brainstate/compile/_conditions_test.py,sha256=s9LF6h9LvigvgxUIugTqvgCHBIU8TXS1Ar1OlIxXfrw,8389
|
25
25
|
brainstate/compile/_error_if.py,sha256=TFvhqITKkRO9m30GdlUP4eEjJvLWQUhjkujXO9zvrWs,2689
|
26
|
-
brainstate/compile/_error_if_test.py,sha256=
|
27
|
-
brainstate/compile/_jit.py,sha256=
|
26
|
+
brainstate/compile/_error_if_test.py,sha256=OdJG483IIdOrCHxtHd49OHfOxCSnSkk7GdAUOzSt8bE,2044
|
27
|
+
brainstate/compile/_jit.py,sha256=itAWENKfJvnlaWl_uSy8lHTK8K1in89F_ZXXwp-EGRM,13944
|
28
28
|
brainstate/compile/_jit_test.py,sha256=zD7kck9SQJGmUDolh9P4luKwQ21fBGje1Z4STTEXIuA,4135
|
29
|
-
brainstate/compile/_loop_collect_return.py,sha256=
|
29
|
+
brainstate/compile/_loop_collect_return.py,sha256=TrKBZhtQecTtuiVz_HOeyepde-znzjlyk0Te53-AvOE,23492
|
30
30
|
brainstate/compile/_loop_collect_return_test.py,sha256=bA-_11E8A_0jR5umEO3e409y7bb5QYDTgSL-SBaX7kQ,1802
|
31
|
-
brainstate/compile/_loop_no_collection.py,sha256=
|
31
|
+
brainstate/compile/_loop_no_collection.py,sha256=qto2__Zt2PJntkjB9AXEgraGLvNUJS483BhCXjJyqv0,7495
|
32
32
|
brainstate/compile/_loop_no_collection_test.py,sha256=oStB1CSG_iLp9sHdXd1hJNFvlxbzjck9Iy4sABoJDj4,1419
|
33
|
-
brainstate/compile/_make_jaxpr.py,sha256=
|
34
|
-
brainstate/compile/_make_jaxpr_test.py,sha256=
|
35
|
-
brainstate/compile/_progress_bar.py,sha256=
|
36
|
-
brainstate/compile/_unvmap.py,sha256=
|
33
|
+
brainstate/compile/_make_jaxpr.py,sha256=DQf_80w3p0wi2Gb9P6_tLMJ0Oadgyr_jWkVjus0MSjw,33205
|
34
|
+
brainstate/compile/_make_jaxpr_test.py,sha256=3XaX8LUuG6UjolcD83qDVo5odf8FCDppdr9Q6V0NBs4,4303
|
35
|
+
brainstate/compile/_progress_bar.py,sha256=0oVlZ4kW_ZMciJjOR_ebj3PNe_XkCMkoQpv-HUUdoF0,5554
|
36
|
+
brainstate/compile/_unvmap.py,sha256=EY4rbqCzzPOiaRwpWTiyBwb5dVkYFnacHhBZUZObxPI,4255
|
37
37
|
brainstate/compile/_util.py,sha256=aCvkTV--g4NsqcodTdBAISt4EwgezCbKzNUV58n-Q_Y,6304
|
38
|
-
brainstate/event/__init__.py,sha256=
|
39
|
-
brainstate/event/_csr.py,sha256=
|
38
|
+
brainstate/event/__init__.py,sha256=gSEem-1oTHgy99Mjm3uumTXVd93tLVl0c4dUgRpoifk,895
|
39
|
+
brainstate/event/_csr.py,sha256=PYKw8CGNgQ24MxQDoeBZTrPuC7Z-GetXQld9KiTbNYw,40063
|
40
|
+
brainstate/event/_csr_benchmark.py,sha256=xrj2DSWzw0pUHAE1jRBeSRhMW7ogXvDHEdeaZGioNE4,702
|
40
41
|
brainstate/event/_csr_mv.py,sha256=HStHvK3KyEMfLsIUslZjgbdU6OsD1yKGrzQOzBXG36M,10266
|
41
|
-
brainstate/event/
|
42
|
-
brainstate/event/
|
43
|
-
brainstate/event/_csr_test.py,sha256=v59rnwTy8jrvqjdGzN75kvLg0wLBmRbthaVRKY2f0Uw,2945
|
44
|
-
brainstate/event/_fixedprob_mv.py,sha256=HP5uyFwue5ZNhsU71ZedMQ-Kp5-st89aLKGNhCBmBRA,25457
|
42
|
+
brainstate/event/_csr_test.py,sha256=_iXwUFq90GU7npVOUnlI4NA27RJ8zyCZBxe7NDH803o,9533
|
43
|
+
brainstate/event/_fixedprob_mv.py,sha256=nR3lhd87t1Vge435QHnFuDp-UBbWoW0Qk1kbsjRHQyc,25541
|
45
44
|
brainstate/event/_fixedprob_mv_benchmark.py,sha256=_F_8fH5MNMJZHeSqnq9DYMI9OgYr6JIxBKjbsgeWRv4,4720
|
46
|
-
brainstate/event/_fixedprob_mv_test.py,sha256=
|
45
|
+
brainstate/event/_fixedprob_mv_test.py,sha256=pVEarvGbqTjnAbxgMVRTAhkyYbvDnlyCJdeOdDD927w,4283
|
47
46
|
brainstate/event/_linear_mv.py,sha256=O5qbY31GNV1qEDrZ5kvPbA8Ae-bY5JpUgGtqDFNAeV0,11794
|
48
47
|
brainstate/event/_linear_mv_benckmark.py,sha256=hu0WqYMIa3jMoH7Fq9dgxcBjjXGFhghPx9vztyCo1KY,2411
|
49
48
|
brainstate/event/_linear_mv_test.py,sha256=V9w41ZP2vu95CyCdCkm-j9Eftqs2kqmeBY809N1-syY,3736
|
50
49
|
brainstate/event/_misc.py,sha256=8IpPooXjF2m0-tuo3pGHqThq2yLSNmYziy_zdurZ3NI,1040
|
51
|
-
brainstate/event/_xla_custom_op.py,sha256=
|
50
|
+
brainstate/event/_xla_custom_op.py,sha256=wF_nKgLUv1IGd8OY89MYqIvyZITl8UcrVysJWFugJxY,11093
|
52
51
|
brainstate/event/_xla_custom_op_test.py,sha256=rnkGMleXzLfJj4y5QqwfBvCCLTAHe_uabwBDniY-URM,1745
|
53
52
|
brainstate/functional/__init__.py,sha256=j6-3Er4fgqWpvntzYCZVB3e5hoz-Z3aqvapITCuDri0,1107
|
54
53
|
brainstate/functional/_activations.py,sha256=S0Ok7sq5FTbmJWSejpOCHo1jpKX0gYOLy_TO2IUXM8s,21726
|
@@ -56,14 +55,11 @@ brainstate/functional/_activations_test.py,sha256=T___RlSrIfXwlkw8dg5A9EZMTZGDzv
|
|
56
55
|
brainstate/functional/_normalization.py,sha256=i2EV7hSsqcNdcYRX2wAxjq8doHwyN9eNJTGTaPt03xE,2605
|
57
56
|
brainstate/functional/_others.py,sha256=_u_Ys-LiLzDAP4zJggVwaVvirgoS3jvhXMREoS6JOkM,1737
|
58
57
|
brainstate/functional/_spikes.py,sha256=QY-2ayJkgkGELcq-bftPEaf_hJptVf_SP3fY36QvlZc,2678
|
59
|
-
brainstate/graph/__init__.py,sha256=
|
60
|
-
brainstate/graph/
|
61
|
-
brainstate/graph/_graph_context_test.py,sha256=IYpjqbXwSFF65XL0ZbdPeC1jYyEHLpQVrhuFeJXH4GM,2409
|
62
|
-
brainstate/graph/_graph_convert.py,sha256=llSREtGQrIggkD0wmxUbYKuSveLW4ihDZME6Ab-mRTQ,9147
|
63
|
-
brainstate/graph/_graph_node.py,sha256=mmZ0jhZev8ReNJhVLgWqYJEedEDtJHxhwxRv4ytQVNo,9268
|
58
|
+
brainstate/graph/__init__.py,sha256=fyvQMlAUY3QYTzvDzz5TDoWS2XQwZ6P3ic6BtysZyHM,1026
|
59
|
+
brainstate/graph/_graph_node.py,sha256=swAokZLKswSTaq2WEhyLIs38sy_67C6maHI6T3e1hvY,8339
|
64
60
|
brainstate/graph/_graph_node_test.py,sha256=BFGfdzZFDHI0XK7hHotSVWKt3em1taGvn8FHF9NCXx8,2702
|
65
|
-
brainstate/graph/_graph_operation.py,sha256=
|
66
|
-
brainstate/graph/_graph_operation_test.py,sha256=
|
61
|
+
brainstate/graph/_graph_operation.py,sha256=cIwGo3ICgtce2fmdn917r81evMFjJIKeW9doaQK4DD8,64111
|
62
|
+
brainstate/graph/_graph_operation_test.py,sha256=zjvpKjQAFWtw8YZuqOk_jmlZNb_-E8oPyNx57dyc8jI,18556
|
67
63
|
brainstate/init/__init__.py,sha256=R1dHgub47o-WJM9QkFLc7x_Q7GsyaKKDtrRHTFPpC5g,1097
|
68
64
|
brainstate/init/_base.py,sha256=B_NLS9aKNrvuj5NAlSgBbQTVev7IRvzcx8vH0J-Gq2w,1671
|
69
65
|
brainstate/init/_generic.py,sha256=sGOvd_atpxLWqqZKobTfAiMiYRnDC19PBNHdQy_igFM,8028
|
@@ -83,7 +79,7 @@ brainstate/nn/_dyn_impl/_dynamics_neuron.py,sha256=cTbIn41EPYG0h3ICzKBXxpgB6wwA2
|
|
83
79
|
brainstate/nn/_dyn_impl/_dynamics_neuron_test.py,sha256=Tfzrzu7udGrLJGnqItiLWe5WT0dgduvYOgzGCnaPJQg,6317
|
84
80
|
brainstate/nn/_dyn_impl/_dynamics_synapse.py,sha256=MsbPyaiDyjungyzuK2b3exRGaMpZgmsmmNHNLjgxQKw,15269
|
85
81
|
brainstate/nn/_dyn_impl/_dynamics_synapse_test.py,sha256=t5i-HV0ii9sUNzWTEv04o26QVtQ-mCdMJcFq2MD755A,4981
|
86
|
-
brainstate/nn/_dyn_impl/_inputs.py,sha256=
|
82
|
+
brainstate/nn/_dyn_impl/_inputs.py,sha256=UNoGxKIKXwPnhelljDowqAWlV6ds7aBBkEbvdy2oDI4,11302
|
87
83
|
brainstate/nn/_dyn_impl/_projection_alignpost.py,sha256=PNC1Tzx_SF2DHAHeJCufXzO_Q4qLoBpWABI45B3GRuc,876
|
88
84
|
brainstate/nn/_dyn_impl/_rate_rnns.py,sha256=dz_yT_6hJVhKulcjIARbGtmMzZqISws96CtBc6o5GOo,14768
|
89
85
|
brainstate/nn/_dyn_impl/_rate_rnns_test.py,sha256=gNgtr-a4ZiU1XF9wFG1HiJ9fLosfWchVR9Zn1x39xt4,2452
|
@@ -102,10 +98,10 @@ brainstate/nn/_elementwise/_dropout_test.py,sha256=ZzNvjFf46NpKWGBIcT6O0lKOBGpxO
|
|
102
98
|
brainstate/nn/_elementwise/_elementwise.py,sha256=om-KpwDTk5yFG5KBYXXHquRLV7s28_FJjk-omvyMyvQ,33342
|
103
99
|
brainstate/nn/_elementwise/_elementwise_test.py,sha256=SZI9jB39sZ5SO1dpWGW-PhodthwN0GU9FY1nqf2fWcs,5341
|
104
100
|
brainstate/nn/_interaction/__init__.py,sha256=TTY_SeNrdx4VnUSw6vdyl02OHdS9Qs15cWBp6kjsyNQ,1289
|
105
|
-
brainstate/nn/_interaction/_conv.py,sha256=
|
101
|
+
brainstate/nn/_interaction/_conv.py,sha256=lwyxTVsJVPiKlZcgB6iqE64aX7AOJzplDSj4y6-m18o,18592
|
106
102
|
brainstate/nn/_interaction/_conv_test.py,sha256=fHXRFYnDghFiKre63RqMwIE_gbPKdK34UPhKOz-J3qU,8695
|
107
103
|
brainstate/nn/_interaction/_embedding.py,sha256=iK0I1ExKWFa_QzV9UDGj32Ljsmdr1g_LlAtMcusebxU,2187
|
108
|
-
brainstate/nn/_interaction/_linear.py,sha256=
|
104
|
+
brainstate/nn/_interaction/_linear.py,sha256=EnkOk1oE79rvRIjU6HBllxUpVOEcQQCj4vtavo9AJjI,14767
|
109
105
|
brainstate/nn/_interaction/_linear_test.py,sha256=QfCR8SBBed9OnSY-AmQ0kJqoggDA3Xem0dRJ0BusxLU,2872
|
110
106
|
brainstate/nn/_interaction/_normalizations.py,sha256=7YDzkmO_iqd70fH_wawb60Bu8eGOdvZq23emP-b68Hc,37440
|
111
107
|
brainstate/nn/_interaction/_normalizations_test.py,sha256=2p1Jf8nA999VYGWbvOZfKYlKk6UmL0vaEB76xkXxkXw,2438
|
@@ -119,26 +115,24 @@ brainstate/optim/_optax_optimizer.py,sha256=SuXV_xUBfhOw1_C2J5TIpy3dXDtI9VJFaSML
|
|
119
115
|
brainstate/optim/_optax_optimizer_test.py,sha256=DAomE8Eu3dn4gh1S3EZ_u4pW4rhcl16vWPbnDcN3Rs4,1762
|
120
116
|
brainstate/optim/_sgd_optimizer.py,sha256=NVKYhGcw2D1ksNWUIXZcj-74LUaan8XL3EERk-EHMRI,46008
|
121
117
|
brainstate/random/__init__.py,sha256=c5q-RC3grRIjx-HBb2IhKZpi_xzbFmUUxzRAzqfREic,1045
|
122
|
-
brainstate/random/_rand_funs.py,sha256=
|
118
|
+
brainstate/random/_rand_funs.py,sha256=WaelvEpeQb6Vuqt4eNgsAtd7GI8BqgEdVYbXgtCOd54,137682
|
123
119
|
brainstate/random/_rand_funs_test.py,sha256=abO5lSoPBgBcg6ecFE1qnCg98__QGa68GSYC5pQW5QI,19438
|
124
|
-
brainstate/random/_rand_seed.py,sha256=
|
120
|
+
brainstate/random/_rand_seed.py,sha256=MHA9znbdJW9ujx73onDRrAOI684_0FmGfqczBsSXYQg,5985
|
125
121
|
brainstate/random/_rand_seed_test.py,sha256=Qibcs-ZqCvj1LuucmQ8H00B_HBNhf2f6un0aUdNZNTw,1518
|
126
|
-
brainstate/random/_rand_state.py,sha256=
|
122
|
+
brainstate/random/_rand_state.py,sha256=nuoQ8GU1MfJPRNN-ZmRQsggVjoyPhaEdZmwM7_4-Q3c,55206
|
127
123
|
brainstate/random/_random_for_unit.py,sha256=kGp4EUX19MXJ9Govoivbg8N0bddqOldKEI2h_TbdONY,2057
|
128
|
-
brainstate/util/__init__.py,sha256
|
129
|
-
brainstate/util/_caller.py,sha256=
|
130
|
-
brainstate/util/_dict.py,sha256=
|
124
|
+
brainstate/util/__init__.py,sha256=-FWEuSKXG3mWxYphGFAy3UEuVe39lFs1GruluzdXDoI,1502
|
125
|
+
brainstate/util/_caller.py,sha256=T3bzu7-09r-6EOrU6Muca_aMXSQua_X2lXjEqb-w39w,2782
|
126
|
+
brainstate/util/_dict.py,sha256=Yapug-_RZQYjvd8cZ3v90_MX7rUYJDBzBnZJT6a0NXY,26178
|
131
127
|
brainstate/util/_dict_test.py,sha256=Dn0TdjX6wLBXaTD4jfYTu6cKfFHwKSxi4_3bX7kB_IA,5621
|
132
128
|
brainstate/util/_error.py,sha256=eyZ8PGFixqe2K5OEfjSDzI-2tU0ieYQoUpBP7yStlPQ,878
|
133
|
-
brainstate/util/_filter.py,sha256=
|
129
|
+
brainstate/util/_filter.py,sha256=1-bvFHdjeehvXeHTrCEp8xr25lopKe8d3XZGCNegq0s,4970
|
134
130
|
brainstate/util/_others.py,sha256=jsPZwP-v_5HRV-LB5F0NUsiqr04y8bmGIsu_JMyVcbQ,14762
|
135
|
-
brainstate/util/_pretty_repr.py,sha256=
|
131
|
+
brainstate/util/_pretty_repr.py,sha256=bDpU4gbkS4B8cXBkiN8kBQNmruxiJzDRF-eIqzyeYnM,5716
|
136
132
|
brainstate/util/_scaling.py,sha256=pc_eM_SZVwkY65I4tJh1ODiHNCoEhsfFXl2zBK0PLAg,7562
|
137
|
-
brainstate/util/_struct.py,sha256=
|
138
|
-
brainstate/
|
139
|
-
brainstate/
|
140
|
-
brainstate-0.1.0.
|
141
|
-
brainstate-0.1.0.
|
142
|
-
brainstate-0.1.0.
|
143
|
-
brainstate-0.1.0.post20250104.dist-info/top_level.txt,sha256=eQbGgKn0ptx7FDWuua0V0wr4K1VHi2iOUCYo3fUQBRA,11
|
144
|
-
brainstate-0.1.0.post20250104.dist-info/RECORD,,
|
133
|
+
brainstate/util/_struct.py,sha256=KMMHcshOM20gYhSahNzWLxsTt-Rt3AeX3Uz26-rP9vI,17619
|
134
|
+
brainstate-0.1.0.post20250120.dist-info/LICENSE,sha256=VZe9u1jgUL2eCY6ZPOYgdb8KCblCHt8ECdbtJid6e1s,11550
|
135
|
+
brainstate-0.1.0.post20250120.dist-info/METADATA,sha256=vUyr4XjiyAW68waFKMray9EEFHTqjqRp5GlqAG8LsKY,3585
|
136
|
+
brainstate-0.1.0.post20250120.dist-info/WHEEL,sha256=bb2Ot9scclHKMOLDEHY6B2sicWOgugjFKaJsT7vwMQo,110
|
137
|
+
brainstate-0.1.0.post20250120.dist-info/top_level.txt,sha256=eQbGgKn0ptx7FDWuua0V0wr4K1VHi2iOUCYo3fUQBRA,11
|
138
|
+
brainstate-0.1.0.post20250120.dist-info/RECORD,,
|
brainstate/event/_csr_mv_test.py
DELETED
@@ -1,118 +0,0 @@
|
|
1
|
-
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
15
|
-
|
16
|
-
from __future__ import annotations
|
17
|
-
|
18
|
-
import jax.numpy
|
19
|
-
import jax.numpy as jnp
|
20
|
-
import numpy as np
|
21
|
-
from absl.testing import parameterized
|
22
|
-
|
23
|
-
import brainstate as bst
|
24
|
-
|
25
|
-
|
26
|
-
def _get_csr(n_pre, n_post, prob):
|
27
|
-
n_conn = int(n_post * prob)
|
28
|
-
indptr = np.arange(n_pre + 1) * n_conn
|
29
|
-
indices = np.random.randint(0, n_post, (n_pre * n_conn,))
|
30
|
-
return indptr, indices
|
31
|
-
|
32
|
-
|
33
|
-
def true_fn(x, w, indices, indptr, n_out):
|
34
|
-
homo_w = jnp.size(w) == 1
|
35
|
-
|
36
|
-
post = jnp.zeros((n_out,))
|
37
|
-
for i_pre in range(x.shape[0]):
|
38
|
-
ids = indices[indptr[i_pre]: indptr[i_pre + 1]]
|
39
|
-
post = post.at[ids].add(w * x[i_pre] if homo_w else w[indptr[i_pre]: indptr[i_pre + 1]] * x[i_pre])
|
40
|
-
return post
|
41
|
-
|
42
|
-
|
43
|
-
# class TestFixedProbCSR(parameterized.TestCase):
|
44
|
-
# @parameterized.product(
|
45
|
-
# homo_w=[True, False],
|
46
|
-
# )
|
47
|
-
# def test1(self, homo_w):
|
48
|
-
# x = bst.random.rand(20) < 0.1
|
49
|
-
# indptr, indices = _get_csr(20, 40, 0.1)
|
50
|
-
# m = bst.event.CSRLinear(20, 40, indptr, indices, 1.5 if homo_w else bst.init.Normal())
|
51
|
-
# y = m(x)
|
52
|
-
# y2 = true_fn(x, m.weight.value, indices, indptr, 40)
|
53
|
-
# self.assertTrue(jnp.allclose(y, y2))
|
54
|
-
#
|
55
|
-
# @parameterized.product(
|
56
|
-
# bool_x=[True, False],
|
57
|
-
# homo_w=[True, False]
|
58
|
-
# )
|
59
|
-
# def test_vjp(self, bool_x, homo_w):
|
60
|
-
# n_in = 20
|
61
|
-
# n_out = 30
|
62
|
-
# if bool_x:
|
63
|
-
# x = jax.numpy.asarray(bst.random.rand(n_in) < 0.3, dtype=float)
|
64
|
-
# else:
|
65
|
-
# x = bst.random.rand(n_in)
|
66
|
-
#
|
67
|
-
# indptr, indices = _get_csr(n_in, n_out, 0.1)
|
68
|
-
# fn = bst.event.CSRLinear(n_in, n_out, indptr, indices, 1.5 if homo_w else bst.init.Normal())
|
69
|
-
# w = fn.weight.value
|
70
|
-
#
|
71
|
-
# def f(x, w):
|
72
|
-
# fn.weight.value = w
|
73
|
-
# return fn(x).sum()
|
74
|
-
#
|
75
|
-
# r = jax.grad(f, argnums=(0, 1))(x, w)
|
76
|
-
#
|
77
|
-
# # -------------------
|
78
|
-
# # TRUE gradients
|
79
|
-
#
|
80
|
-
# def f2(x, w):
|
81
|
-
# return true_fn(x, w, indices, indptr, n_out).sum()
|
82
|
-
#
|
83
|
-
# r2 = jax.grad(f2, argnums=(0, 1))(x, w)
|
84
|
-
# self.assertTrue(jnp.allclose(r[0], r2[0]))
|
85
|
-
# self.assertTrue(jnp.allclose(r[1], r2[1]))
|
86
|
-
#
|
87
|
-
# @parameterized.product(
|
88
|
-
# bool_x=[True, False],
|
89
|
-
# homo_w=[True, False]
|
90
|
-
# )
|
91
|
-
# def test_jvp(self, bool_x, homo_w):
|
92
|
-
# n_in = 20
|
93
|
-
# n_out = 30
|
94
|
-
# if bool_x:
|
95
|
-
# x = jax.numpy.asarray(bst.random.rand(n_in) < 0.3, dtype=float)
|
96
|
-
# else:
|
97
|
-
# x = bst.random.rand(n_in)
|
98
|
-
#
|
99
|
-
# indptr, indices = _get_csr(n_in, n_out, 0.1)
|
100
|
-
# fn = bst.event.CSRLinear(n_in, n_out, indptr, indices,
|
101
|
-
# 1.5 if homo_w else bst.init.Normal(), grad_mode='jvp')
|
102
|
-
# w = fn.weight.value
|
103
|
-
#
|
104
|
-
# def f(x, w):
|
105
|
-
# fn.weight.value = w
|
106
|
-
# return fn(x)
|
107
|
-
#
|
108
|
-
# o1, r1 = jax.jvp(f, (x, w), (jnp.ones_like(x), jnp.ones_like(w)))
|
109
|
-
#
|
110
|
-
# # -------------------
|
111
|
-
# # TRUE gradients
|
112
|
-
#
|
113
|
-
# def f2(x, w):
|
114
|
-
# return true_fn(x, w, indices, indptr, n_out)
|
115
|
-
#
|
116
|
-
# o2, r2 = jax.jvp(f2, (x, w), (jnp.ones_like(x), jnp.ones_like(w)))
|
117
|
-
# self.assertTrue(jnp.allclose(r1, r2))
|
118
|
-
# self.assertTrue(jnp.allclose(o1, o2))
|