brainstate 0.1.9__py2.py3-none-any.whl → 0.2.0__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +130 -19
- brainstate/_compatible_import.py +201 -9
- brainstate/_compatible_import_test.py +681 -0
- brainstate/_deprecation.py +210 -0
- brainstate/_deprecation_test.py +2319 -0
- brainstate/{util/error.py → _error.py} +10 -20
- brainstate/_state.py +94 -47
- brainstate/_state_test.py +1 -1
- brainstate/_utils.py +1 -1
- brainstate/environ.py +1279 -347
- brainstate/environ_test.py +1187 -26
- brainstate/graph/__init__.py +6 -13
- brainstate/graph/_node.py +240 -0
- brainstate/graph/_node_test.py +589 -0
- brainstate/graph/{_graph_operation.py → _operation.py} +632 -746
- brainstate/graph/_operation_test.py +1147 -0
- brainstate/mixin.py +1209 -141
- brainstate/mixin_test.py +991 -51
- brainstate/nn/__init__.py +74 -72
- brainstate/nn/_activations.py +587 -295
- brainstate/nn/_activations_test.py +109 -86
- brainstate/nn/_collective_ops.py +393 -274
- brainstate/nn/_collective_ops_test.py +746 -15
- brainstate/nn/_common.py +114 -66
- brainstate/nn/_common_test.py +154 -0
- brainstate/nn/_conv.py +1652 -143
- brainstate/nn/_conv_test.py +838 -227
- brainstate/nn/_delay.py +95 -29
- brainstate/nn/_delay_test.py +25 -20
- brainstate/nn/_dropout.py +359 -167
- brainstate/nn/_dropout_test.py +429 -52
- brainstate/nn/_dynamics.py +14 -90
- brainstate/nn/_dynamics_test.py +1 -12
- brainstate/nn/_elementwise.py +492 -313
- brainstate/nn/_elementwise_test.py +806 -145
- brainstate/nn/_embedding.py +369 -19
- brainstate/nn/_embedding_test.py +156 -0
- brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +10 -16
- brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +6 -5
- brainstate/nn/{_linear_mv.py → _event_linear.py} +2 -2
- brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +6 -5
- brainstate/nn/_exp_euler.py +200 -38
- brainstate/nn/_exp_euler_test.py +350 -8
- brainstate/nn/_linear.py +391 -71
- brainstate/nn/_linear_test.py +427 -59
- brainstate/nn/_metrics.py +1070 -0
- brainstate/nn/_metrics_test.py +611 -0
- brainstate/nn/_module.py +10 -3
- brainstate/nn/_module_test.py +1 -1
- brainstate/nn/_normalizations.py +688 -329
- brainstate/nn/_normalizations_test.py +663 -37
- brainstate/nn/_paddings.py +1020 -0
- brainstate/nn/_paddings_test.py +723 -0
- brainstate/nn/_poolings.py +1404 -342
- brainstate/nn/_poolings_test.py +828 -92
- brainstate/nn/{_rate_rnns.py → _rnns.py} +446 -54
- brainstate/nn/_rnns_test.py +593 -0
- brainstate/nn/_utils.py +132 -5
- brainstate/nn/_utils_test.py +402 -0
- brainstate/{init/_random_inits.py → nn/init.py} +301 -45
- brainstate/{init/_random_inits_test.py → nn/init_test.py} +51 -20
- brainstate/random/__init__.py +247 -1
- brainstate/random/_rand_funs.py +668 -346
- brainstate/random/_rand_funs_test.py +74 -1
- brainstate/random/_rand_seed.py +541 -76
- brainstate/random/_rand_seed_test.py +1 -1
- brainstate/random/_rand_state.py +601 -393
- brainstate/random/_rand_state_test.py +551 -0
- brainstate/transform/__init__.py +59 -0
- brainstate/transform/_ad_checkpoint.py +176 -0
- brainstate/{compile → transform}/_ad_checkpoint_test.py +1 -1
- brainstate/{augment → transform}/_autograd.py +360 -113
- brainstate/{augment → transform}/_autograd_test.py +2 -2
- brainstate/transform/_conditions.py +316 -0
- brainstate/{compile → transform}/_conditions_test.py +11 -11
- brainstate/{compile → transform}/_error_if.py +22 -20
- brainstate/{compile → transform}/_error_if_test.py +1 -1
- brainstate/transform/_eval_shape.py +145 -0
- brainstate/{augment → transform}/_eval_shape_test.py +1 -1
- brainstate/{compile → transform}/_jit.py +99 -46
- brainstate/{compile → transform}/_jit_test.py +3 -3
- brainstate/{compile → transform}/_loop_collect_return.py +219 -80
- brainstate/{compile → transform}/_loop_collect_return_test.py +1 -1
- brainstate/{compile → transform}/_loop_no_collection.py +133 -34
- brainstate/{compile → transform}/_loop_no_collection_test.py +2 -2
- brainstate/transform/_make_jaxpr.py +2016 -0
- brainstate/transform/_make_jaxpr_test.py +1510 -0
- brainstate/transform/_mapping.py +529 -0
- brainstate/transform/_mapping_test.py +194 -0
- brainstate/{compile → transform}/_progress_bar.py +78 -25
- brainstate/{augment → transform}/_random.py +65 -45
- brainstate/{compile → transform}/_unvmap.py +102 -5
- brainstate/transform/_util.py +286 -0
- brainstate/typing.py +594 -61
- brainstate/typing_test.py +780 -0
- brainstate/util/__init__.py +9 -32
- brainstate/util/_others.py +1025 -0
- brainstate/util/_others_test.py +962 -0
- brainstate/util/_pretty_pytree.py +1301 -0
- brainstate/util/_pretty_pytree_test.py +675 -0
- brainstate/util/{pretty_repr.py → _pretty_repr.py} +161 -27
- brainstate/util/_pretty_repr_test.py +696 -0
- brainstate/util/filter.py +557 -81
- brainstate/util/filter_test.py +912 -0
- brainstate/util/struct.py +769 -382
- brainstate/util/struct_test.py +602 -0
- {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/METADATA +34 -17
- brainstate-0.2.0.dist-info/RECORD +111 -0
- brainstate/augment/__init__.py +0 -30
- brainstate/augment/_eval_shape.py +0 -99
- brainstate/augment/_mapping.py +0 -1060
- brainstate/augment/_mapping_test.py +0 -597
- brainstate/compile/__init__.py +0 -38
- brainstate/compile/_ad_checkpoint.py +0 -204
- brainstate/compile/_conditions.py +0 -256
- brainstate/compile/_make_jaxpr.py +0 -888
- brainstate/compile/_make_jaxpr_test.py +0 -156
- brainstate/compile/_util.py +0 -147
- brainstate/functional/__init__.py +0 -27
- brainstate/graph/_graph_node.py +0 -244
- brainstate/graph/_graph_node_test.py +0 -73
- brainstate/graph/_graph_operation_test.py +0 -563
- brainstate/init/__init__.py +0 -26
- brainstate/init/_base.py +0 -52
- brainstate/init/_generic.py +0 -244
- brainstate/init/_regular_inits.py +0 -105
- brainstate/init/_regular_inits_test.py +0 -50
- brainstate/nn/_inputs.py +0 -608
- brainstate/nn/_ltp.py +0 -28
- brainstate/nn/_neuron.py +0 -705
- brainstate/nn/_neuron_test.py +0 -161
- brainstate/nn/_others.py +0 -46
- brainstate/nn/_projection.py +0 -486
- brainstate/nn/_rate_rnns_test.py +0 -63
- brainstate/nn/_readout.py +0 -209
- brainstate/nn/_readout_test.py +0 -53
- brainstate/nn/_stp.py +0 -236
- brainstate/nn/_synapse.py +0 -505
- brainstate/nn/_synapse_test.py +0 -131
- brainstate/nn/_synaptic_projection.py +0 -423
- brainstate/nn/_synouts.py +0 -162
- brainstate/nn/_synouts_test.py +0 -57
- brainstate/nn/metrics.py +0 -388
- brainstate/optim/__init__.py +0 -38
- brainstate/optim/_base.py +0 -64
- brainstate/optim/_lr_scheduler.py +0 -448
- brainstate/optim/_lr_scheduler_test.py +0 -50
- brainstate/optim/_optax_optimizer.py +0 -152
- brainstate/optim/_optax_optimizer_test.py +0 -53
- brainstate/optim/_sgd_optimizer.py +0 -1104
- brainstate/random/_random_for_unit.py +0 -52
- brainstate/surrogate.py +0 -1957
- brainstate/transform.py +0 -23
- brainstate/util/caller.py +0 -98
- brainstate/util/others.py +0 -540
- brainstate/util/pretty_pytree.py +0 -945
- brainstate/util/pretty_pytree_test.py +0 -159
- brainstate/util/pretty_table.py +0 -2954
- brainstate/util/scaling.py +0 -258
- brainstate-0.1.9.dist-info/RECORD +0 -130
- {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/WHEEL +0 -0
- {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/licenses/LICENSE +0 -0
- {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,111 @@
|
|
1
|
+
brainstate/__init__.py,sha256=bmCIZG6xMiVlZ5qWk-d5qpkqdnrWxFhLePtpu3ff5WI,6024
|
2
|
+
brainstate/_compatible_import.py,sha256=7thV_2F0FD5AF2DETjBfmtNb_2ZQzki8NxFgC62frg0,11037
|
3
|
+
brainstate/_compatible_import_test.py,sha256=6ka26Sa_Kk6F-Ar1HR6UaKJTHquXcUCWglgXBUOovcg,22762
|
4
|
+
brainstate/_deprecation.py,sha256=gSh36_TWLBgQAo0gNfOzscV9ssa26k3te9y25BG6O2w,8381
|
5
|
+
brainstate/_deprecation_test.py,sha256=5_tJ9JDhG79zLwe5MuBLsqcgl7gi-oXPgDTCujrPmz0,88325
|
6
|
+
brainstate/_error.py,sha256=6A5ILy17ZMMZIjS8LkajTZBDRnwv_Qait5x__h2Levo,1522
|
7
|
+
brainstate/_state.py,sha256=YCYKrX2xCTojIY15vtQl8WWjtJ0HNMg8y5LCUk1d_3Q,58096
|
8
|
+
brainstate/_state_test.py,sha256=wCoWTbvARVkhNokMLLBV6sEoHSE9OJRyuUVrNDcRiG8,1589
|
9
|
+
brainstate/_utils.py,sha256=cmUyO9ds1etrrpV4ucp1G8mDqE15g4ZtbivblH_cD9o,1613
|
10
|
+
brainstate/environ.py,sha256=BmQsvo1aZaMpckHXlJ45dZh9DUdnHHP9Q9JdLNCa9wA,42169
|
11
|
+
brainstate/environ_test.py,sha256=RdVmeP7irbk3_qNjwWoy-DSdpkTRpxqAABtNGZcbB2w,42418
|
12
|
+
brainstate/mixin.py,sha256=O58rGznYowz-hBQy7iKOhbrg3Dze7FAR8AU4CX8hUEI,44635
|
13
|
+
brainstate/mixin_test.py,sha256=6WmqJf34kT3Z5WaiCNDo3OV3ci0DIHCQp474zECDxEU,34718
|
14
|
+
brainstate/typing.py,sha256=pYiNI-9oHpH7HfjRKYxugK03KGiamCwweagMyO0rsi4,26301
|
15
|
+
brainstate/typing_test.py,sha256=2mmMW0uAzIo3_VXpT5Boq79BohxYkzBlHexBysFUGII,26240
|
16
|
+
brainstate/graph/__init__.py,sha256=kGVtHAnkiWR5MqDYQU0G3AobWnioGeDqjILA--RyDz8,846
|
17
|
+
brainstate/graph/_node.py,sha256=_XH8xx6_glsCK4KCsQnarACK8meyhCdfh3nWfUDko0k,6407
|
18
|
+
brainstate/graph/_node_test.py,sha256=sD2DS0AhDKOU5ZQm0cYz0llnJ6D60ftNfDpztk4i8cM,18687
|
19
|
+
brainstate/graph/_operation.py,sha256=n2HqfwPzG3f6QHish_7d8lNZRhHUyS4_YHTdPMGgdUk,54096
|
20
|
+
brainstate/graph/_operation_test.py,sha256=IVyrJh4io3sDgtrTEIAItGzNs2XEf7rO1rvI1r_KiII,39119
|
21
|
+
brainstate/nn/__init__.py,sha256=oEQ81xpWppPMUMCRalMhAhUU5faYs0aKcyydwjoxGVo,4759
|
22
|
+
brainstate/nn/_activations.py,sha256=6jHR67obYR1lpo-imVXmfd3m_NDyU0XZb8t-pVYDvUU,26917
|
23
|
+
brainstate/nn/_activations_test.py,sha256=Ikr8RYBaIpApVKUhY-XAWr6llEG7vYWS8YuqDHyTtBY,13438
|
24
|
+
brainstate/nn/_collective_ops.py,sha256=UditkuCy0f9ggkFwGUXyJDoFtDdVUIQU3xmD5YOxRKg,21280
|
25
|
+
brainstate/nn/_collective_ops_test.py,sha256=8mKQkfTjfwuO7DA1i_Yr4QD1yg2ZJgVpX7zhHFo9CuQ,25600
|
26
|
+
brainstate/nn/_common.py,sha256=UyJMJoVF9KfrToOX5Dbv-2s3CD49SsroLRfL17DLl4Q,7184
|
27
|
+
brainstate/nn/_common_test.py,sha256=NpKQpKPPU4I2nVpz1X_bOGC5R-RyCP2kYlyYqXWoKAE,5966
|
28
|
+
brainstate/nn/_conv.py,sha256=3cGToc5UoGN5jp4BUHlrK6md8O_0IcWMCBXFHBQg7nE,82106
|
29
|
+
brainstate/nn/_conv_test.py,sha256=65FlrteUxLQb8ckUUJFhaPZDsQZTlGVAp0HUGUJtt1M,30173
|
30
|
+
brainstate/nn/_delay.py,sha256=0brSPlKOC7_nNWr1zcdjXNHJFw_I9KUX7PGyVWXENTk,22331
|
31
|
+
brainstate/nn/_delay_test.py,sha256=FzBb8vXfse8HEcEid83hpa6aag6oj90mtcHYsDs0DOE,10376
|
32
|
+
brainstate/nn/_dropout.py,sha256=UotjW0PQO4gypfhtSqzkR4UVGkY0kNBncEB_poGL7Sc,22555
|
33
|
+
brainstate/nn/_dropout_test.py,sha256=5LEN_G6t2TZAbmvJdZ75h1nyKZTHHiNT_i1W-im5wo8,21461
|
34
|
+
brainstate/nn/_dynamics.py,sha256=V7GA1enppLRRDJCw146RI_CdwpVyRgHUYH9fu2-PJOA,44181
|
35
|
+
brainstate/nn/_dynamics_test.py,sha256=G1VkrWgKK4CXtcKJmevxiQQ9DVeyf1zcdMsAWyCN0Jw,2356
|
36
|
+
brainstate/nn/_elementwise.py,sha256=4kKzrbKn5luwnpY8n7IeaMOtBVVie8oPHEc64hSn8-w,34858
|
37
|
+
brainstate/nn/_elementwise_test.py,sha256=sbWlUyTB8oiu3PRHObTvmUaob99EjIGYE6k6bEpg6K8,27296
|
38
|
+
brainstate/nn/_embedding.py,sha256=eMwbz9udm6WGVEGeHwhvCr7sJA6kMFyaZsli7lL30eg,14942
|
39
|
+
brainstate/nn/_embedding_test.py,sha256=Gc0y6gHMEagaDrBJoAYQZMDTd47TYQNVrenWUwLWK_w,6242
|
40
|
+
brainstate/nn/_event_fixedprob.py,sha256=ZEnIyjDksxtUWWG5GXLcF-RHR1S33DZTw-rN2lHKs0g,9395
|
41
|
+
brainstate/nn/_event_fixedprob_test.py,sha256=rvTKxEzKwvctQc8-AxXjJ4p4D-if1va2O8KqQYW5nxY,3836
|
42
|
+
brainstate/nn/_event_linear.py,sha256=d0J54Sf9zBl926BzXoy4Oc1p96h9veU3f8YhWrZLRPk,2554
|
43
|
+
brainstate/nn/_event_linear_test.py,sha256=qzcGplDIwxTnZOs4JzD5GX_oNBYtcYFNaf3GpWo8pZY,3765
|
44
|
+
brainstate/nn/_exp_euler.py,sha256=W5ofUPe6UK8NXO917PY9_nNXOI9chSHgBmnDHMVXHoo,8635
|
45
|
+
brainstate/nn/_exp_euler_test.py,sha256=21qomGOo96YLmlEQok2hCByAcRpmqREhSr3kxmhKOm8,13014
|
46
|
+
brainstate/nn/_linear.py,sha256=olbo9AmC35UQBxECyAQW06XeucsJfBAAMoSfwOW6c7s,24047
|
47
|
+
brainstate/nn/_linear_test.py,sha256=5fHx4v4_54dH3Bsyapl2cobvVtMEu3DFR-jDKLbkJFw,17876
|
48
|
+
brainstate/nn/_metrics.py,sha256=TgALwv6i9La4Dm1WAkWDWxvxr9rkd7CJLGFy2sOGQbQ,36481
|
49
|
+
brainstate/nn/_metrics_test.py,sha256=XZiRndchRgEH0X8zsHzg0fsHNMxj43mnd883QChfSik,24104
|
50
|
+
brainstate/nn/_module.py,sha256=to280ubWAP-HiCV7LknMBqOhet0UkH9Oh-PkLkual_g,12775
|
51
|
+
brainstate/nn/_module_test.py,sha256=znjB7FU5evJENQ1Pqw7ZlOGC5faQe4-4VpjW60H8UWI,1414
|
52
|
+
brainstate/nn/_normalizations.py,sha256=exdIn627ph5pVHdYx_4NIaK_f1xcLaBj-YQ0Ui22CsA,50185
|
53
|
+
brainstate/nn/_normalizations_test.py,sha256=y5n7aTaUHRkyAAjq4Oj8dfButJC3ehm4KdUB7226Bow,23350
|
54
|
+
brainstate/nn/_paddings.py,sha256=3u3dbRFtPSlIsLMBYZmHcYDQ6HFl0u_d-yP7ZoYcCrA,32415
|
55
|
+
brainstate/nn/_paddings_test.py,sha256=uY9CRexf9sM5V6AzTfzm8214y4IoGqEbahy92yaKsWM,27409
|
56
|
+
brainstate/nn/_poolings.py,sha256=aiLDTgtbDEse2OZlJONdU1CCsBjZtsWqsnL8ffTSGpk,86045
|
57
|
+
brainstate/nn/_poolings_test.py,sha256=cZ4lyutUY-Iti6rkGWjqWZ72fXF5QvwX9mki4CQ1xDs,34164
|
58
|
+
brainstate/nn/_rnns.py,sha256=FmuVsUwQCl7QbePmYokd-ZxKwVgDqa8FdCto_dNFDjk,32947
|
59
|
+
brainstate/nn/_rnns_test.py,sha256=gqPGz64i_QUEXi2fUBrIaBEPZjTs8hm8S6HyPbs3I6M,22112
|
60
|
+
brainstate/nn/_utils.py,sha256=VK-Se53e1q-Ip4AtMOZ3SUzYw8u2UllLJRLRtEFRCRE,7403
|
61
|
+
brainstate/nn/_utils_test.py,sha256=uim2SkfNHrBZzNDvN0WOK8qeZC1kaeOd-UQDvrn_M24,14266
|
62
|
+
brainstate/nn/init.py,sha256=7iLHrL-ZHpU-g5d0PlusaqmtkWO7X_KNr2eLx90oHrc,25656
|
63
|
+
brainstate/nn/init_test.py,sha256=bfby6kovvhbc7CCEaohtQawUQj3w3GvaMSDAdTiP2ps,6200
|
64
|
+
brainstate/random/__init__.py,sha256=2k5aI3GI_ftOtuWO9rosRUhyr5OjnMty82b-CTBkOUU,8383
|
65
|
+
brainstate/random/_rand_funs.py,sha256=hBkZYTlBNHBLdHTLBecT7sja7mVItMF7TV3ao26vfj0,135339
|
66
|
+
brainstate/random/_rand_funs_test.py,sha256=mZaSRtlXCwJP3YfF5GsFelR-YACg-uH1gHWG-coZT_k,23040
|
67
|
+
brainstate/random/_rand_seed.py,sha256=mzSe2lsyhN_0eXhkD8gpRKmSbReQfZF0pZA89cjiiHU,24927
|
68
|
+
brainstate/random/_rand_seed_test.py,sha256=Y2VCAkUzciDaCfYZWPe_Ewmi3MylK-WzfPA7TzorV8Q,1491
|
69
|
+
brainstate/random/_rand_state.py,sha256=_62wIrdZKKvmtjBErB0lRCuf1vYEskFyPd4Nhp_FKDU,53591
|
70
|
+
brainstate/random/_rand_state_test.py,sha256=0Y1kx7rvkACQiiiAHK0plePtLxYeYTzvWUGt3MimSGI,19227
|
71
|
+
brainstate/transform/__init__.py,sha256=CfzUYGQFt9hAp57ZvffUBBhLjudgzJ9n_aTaor57iOk,2241
|
72
|
+
brainstate/transform/_ad_checkpoint.py,sha256=4dcNCEQVV_CPMSkE32URERDMpQHbyfdGeLT_Nvhyd4o,6912
|
73
|
+
brainstate/transform/_ad_checkpoint_test.py,sha256=fPXBjDxsLHbL2mhIU3x_F5BpitkXLpgIsRCRgm2Us6w,1697
|
74
|
+
brainstate/transform/_autograd.py,sha256=4zGSYa9TMn6bqzPJNLfU9UZGZRyYxmwMXTKZWO4w3QQ,39991
|
75
|
+
brainstate/transform/_autograd_test.py,sha256=saWG1_k3cRXpsyQDzQkOLGvsF7IIxG9aGnjrf5B3HNk,44112
|
76
|
+
brainstate/transform/_conditions.py,sha256=IIu_V2f7R74saOoURGUxwzI1RpsnxkzpfB2fes9plOs,11399
|
77
|
+
brainstate/transform/_conditions_test.py,sha256=MEuqRq6IFmyORRDi0qWvNo4pWKFyc8aNrW1v9Saqxj0,8493
|
78
|
+
brainstate/transform/_error_if.py,sha256=e9tp3wT5p4bEyjn_Za_SrPNOG3OIoPBMIrvG2CsZzvw,2680
|
79
|
+
brainstate/transform/_error_if_test.py,sha256=yn-qcZ6lZUWciIif4fJOpdpKzJFAAdfgzZm6FfPeq7U,1848
|
80
|
+
brainstate/transform/_eval_shape.py,sha256=BNbjiFHUsk-qfENiZf1K8yX-x7eIIuAyWR0CjBIBr5w,5355
|
81
|
+
brainstate/transform/_eval_shape_test.py,sha256=4A2NdHcpksiGPf_UmPlMPgHJnes9ciiQO20ZJHVzA9g,1355
|
82
|
+
brainstate/transform/_jit.py,sha256=qYsL3Z9nAAW0UyQe_AyvBEuJvqAan6iw3lN49o1oC0A,15421
|
83
|
+
brainstate/transform/_jit_test.py,sha256=ecw54dGQYdJq2J94itPrXBQrSDNvCY_htUD7z7y4HUM,4013
|
84
|
+
brainstate/transform/_loop_collect_return.py,sha256=HhjC2gq6qzliw4ofP16VxdtR5hW-NmDZdeHxuiLdYGk,25899
|
85
|
+
brainstate/transform/_loop_collect_return_test.py,sha256=BVK-b3CuDtTXciRaA_8t4751N4taQOnIPNzAelSts-k,1753
|
86
|
+
brainstate/transform/_loop_no_collection.py,sha256=ArPpNemMh4jJsq_vUWPxuagCnxTlONN--3P_-44qYq8,10156
|
87
|
+
brainstate/transform/_loop_no_collection_test.py,sha256=3bRo9_Oaypbw3asEevrgTK0WksxDAIZKJgeaWpt7nl8,1371
|
88
|
+
brainstate/transform/_make_jaxpr.py,sha256=A9ivekPuzLUCoUcgSLm5hUuWPj5_1_ksqn2KvztcWhY,73327
|
89
|
+
brainstate/transform/_make_jaxpr_test.py,sha256=dFAnEQavYlEddex6h42ulRkXbV5TSGCDHFXWAvPsy7w,49128
|
90
|
+
brainstate/transform/_mapping.py,sha256=L9Q_1M_GXJA0ZCJtty9x01wViyefgeHyOLga-Jo2AE8,21570
|
91
|
+
brainstate/transform/_mapping_test.py,sha256=yGG-QSo3epqtBvGX8VxrNDKiC-d6TMN_jv_uKPad6eM,7647
|
92
|
+
brainstate/transform/_progress_bar.py,sha256=kZ-mI5hbUQXhqKFVyo0qeKG_LvrR9ZIar7WkXyOeET0,8961
|
93
|
+
brainstate/transform/_random.py,sha256=ZTH5Smx5SvFy6El7qk-ihoYE9WII6TJXcsC9Vm7VgjA,5259
|
94
|
+
brainstate/transform/_unvmap.py,sha256=cW6fjs5Iy1YBB6Nx2mxlM4IzV8U99bEX5QjT8rBRDho,6319
|
95
|
+
brainstate/transform/_util.py,sha256=IYqJj7oyAYzm_m3d9WEsUQRKDdVLQaAWpwM5O8PD4YQ,11304
|
96
|
+
brainstate/util/__init__.py,sha256=anHdG5BIsMqcBQy7gt8lErKInZv1wf2NOLJGUqltyAQ,1154
|
97
|
+
brainstate/util/_others.py,sha256=HSP3ynNv2ocxPz5omQl2rxMTOjGkQNNprsTOCEANCjA,30837
|
98
|
+
brainstate/util/_others_test.py,sha256=gEiUybMxtv11tsD0jge9Y0K8idkWkHVBlTZz1gKXNAU,30535
|
99
|
+
brainstate/util/_pretty_pytree.py,sha256=0fVwW8qHtKrpJU4q2tbn8usSKOlG8lBALn6E6crpf64,46815
|
100
|
+
brainstate/util/_pretty_pytree_test.py,sha256=6DiYX_Rwyvv3XcuQHsU-PuWXmvOqSmiVVN4diVvp76o,22362
|
101
|
+
brainstate/util/_pretty_repr.py,sha256=fafG6SIFoPjaWmQcTHwmnEqbVvcTCoN_lKiBKBk1QTQ,13958
|
102
|
+
brainstate/util/_pretty_repr_test.py,sha256=yfaANMfozlM5E3IZv-LBaNpyvSVB8ftRp0r7_oAr5yA,23122
|
103
|
+
brainstate/util/filter.py,sha256=wY_XUF3OhrXSV1bZkTcVhlEPba4HP1l9N5aRW2zgxqQ,27410
|
104
|
+
brainstate/util/filter_test.py,sha256=ZfrEeOc1yMHYzrcSR3p4jbZGj7c_tXC8VcPq7H13q8E,31653
|
105
|
+
brainstate/util/struct.py,sha256=LYPLGDGfPuw14hhx5k4rb8msSH3yZPdAu_0CvjxPWwE,24505
|
106
|
+
brainstate/util/struct_test.py,sha256=q_fWsUH1ON35DKjUUAMq6VtYglqTDyBvd6WMVGD89EI,16526
|
107
|
+
brainstate-0.2.0.dist-info/licenses/LICENSE,sha256=RJ40fox7u2in2H8wvIS5DsPGlNHaA7JI024thFUlaZE,11348
|
108
|
+
brainstate-0.2.0.dist-info/METADATA,sha256=GrI-RT31978SYlGNnRP9xodMOYl1g6mPpIiP1XdUaOY,4421
|
109
|
+
brainstate-0.2.0.dist-info/WHEEL,sha256=JNWh1Fm1UdwIQV075glCn4MVuCRs0sotJIq-J6rbxCU,109
|
110
|
+
brainstate-0.2.0.dist-info/top_level.txt,sha256=eQbGgKn0ptx7FDWuua0V0wr4K1VHi2iOUCYo3fUQBRA,11
|
111
|
+
brainstate-0.2.0.dist-info/RECORD,,
|
brainstate/augment/__init__.py
DELETED
@@ -1,30 +0,0 @@
|
|
1
|
-
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
15
|
-
|
16
|
-
"""
|
17
|
-
This module includes transformations for augmenting the functionalities of JAX code.
|
18
|
-
"""
|
19
|
-
|
20
|
-
from ._autograd import GradientTransform, grad, vector_grad, hessian, jacobian, jacrev, jacfwd
|
21
|
-
from ._eval_shape import abstract_init
|
22
|
-
from ._mapping import vmap, pmap, map, vmap_new_states
|
23
|
-
from ._random import restore_rngs
|
24
|
-
|
25
|
-
__all__ = [
|
26
|
-
'GradientTransform', 'grad', 'vector_grad', 'hessian', 'jacobian', 'jacrev', 'jacfwd',
|
27
|
-
'abstract_init',
|
28
|
-
'vmap', 'pmap', 'map', 'vmap_new_states',
|
29
|
-
'restore_rngs',
|
30
|
-
]
|
@@ -1,99 +0,0 @@
|
|
1
|
-
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
15
|
-
|
16
|
-
import functools
|
17
|
-
from typing import Any, TypeVar, Callable, Sequence, Union
|
18
|
-
|
19
|
-
import jax
|
20
|
-
|
21
|
-
from brainstate import random
|
22
|
-
from brainstate.graph import Node, flatten, unflatten
|
23
|
-
from ._random import restore_rngs
|
24
|
-
|
25
|
-
__all__ = [
|
26
|
-
'abstract_init',
|
27
|
-
]
|
28
|
-
|
29
|
-
A = TypeVar('A')
|
30
|
-
|
31
|
-
|
32
|
-
def abstract_init(
|
33
|
-
fn: Callable[..., A],
|
34
|
-
*args: Any,
|
35
|
-
rngs: Union[random.RandomState, Sequence[random.RandomState]] = random.DEFAULT,
|
36
|
-
**kwargs: Any,
|
37
|
-
) -> A:
|
38
|
-
"""
|
39
|
-
Compute the shape/dtype of ``fn`` without any FLOPs.
|
40
|
-
|
41
|
-
Here's an example::
|
42
|
-
|
43
|
-
>>> import brainstate
|
44
|
-
>>> class MLP:
|
45
|
-
... def __init__(self, n_in, n_mid, n_out):
|
46
|
-
... self.dense1 = brainstate.nn.Linear(n_in, n_mid)
|
47
|
-
... self.dense2 = brainstate.nn.Linear(n_mid, n_out)
|
48
|
-
|
49
|
-
>>> r = brainstate.augment.abstract_init(lambda: MLP(1, 2, 3))
|
50
|
-
>>> r
|
51
|
-
MLP(
|
52
|
-
dense1=Linear(
|
53
|
-
in_size=(1,),
|
54
|
-
out_size=(2,),
|
55
|
-
w_mask=None,
|
56
|
-
weight=ParamState(
|
57
|
-
value={'bias': ShapeDtypeStruct(shape=(2,), dtype=float32), 'weight': ShapeDtypeStruct(shape=(1, 2), dtype=float32)}
|
58
|
-
)
|
59
|
-
),
|
60
|
-
dense2=Linear(
|
61
|
-
in_size=(2,),
|
62
|
-
out_size=(3,),
|
63
|
-
w_mask=None,
|
64
|
-
weight=ParamState(
|
65
|
-
value={'bias': ShapeDtypeStruct(shape=(3,), dtype=float32), 'weight': ShapeDtypeStruct(shape=(2, 3), dtype=float32)}
|
66
|
-
)
|
67
|
-
)
|
68
|
-
)
|
69
|
-
|
70
|
-
Args:
|
71
|
-
fn: The function whose output shape should be evaluated.
|
72
|
-
*args: a positional argument tuple of arrays, scalars, or (nested) standard
|
73
|
-
Python containers (tuples, lists, dicts, namedtuples, i.e. pytrees) of
|
74
|
-
those types. Since only the ``shape`` and ``dtype`` attributes are
|
75
|
-
accessed, one can use :class:`jax.ShapeDtypeStruct` or another container
|
76
|
-
that duck-types as ndarrays (note however that duck-typed objects cannot
|
77
|
-
be namedtuples because those are treated as standard Python containers).
|
78
|
-
**kwargs: a keyword argument dict of arrays, scalars, or (nested) standard
|
79
|
-
Python containers (pytrees) of those types. As in ``args``, array values
|
80
|
-
need only be duck-typed to have ``shape`` and ``dtype`` attributes.
|
81
|
-
rngs: a :class:`RandomState` or a sequence of :class:`RandomState` objects
|
82
|
-
representing the random number generators to use. If not provided, the
|
83
|
-
default random number generator will be used.
|
84
|
-
|
85
|
-
Returns:
|
86
|
-
out: a nested PyTree containing :class:`jax.ShapeDtypeStruct` objects as leaves.
|
87
|
-
|
88
|
-
"""
|
89
|
-
|
90
|
-
@functools.wraps(fn)
|
91
|
-
@restore_rngs(rngs=rngs)
|
92
|
-
def _eval_shape_fn(*args_, **kwargs_):
|
93
|
-
out = fn(*args_, **kwargs_)
|
94
|
-
assert isinstance(out, Node), 'The output of the function must be Node'
|
95
|
-
graph_def, treefy_states = flatten(out)
|
96
|
-
return graph_def, treefy_states
|
97
|
-
|
98
|
-
graph_def_, treefy_states_ = jax.eval_shape(_eval_shape_fn, *args, **kwargs)
|
99
|
-
return unflatten(graph_def_, treefy_states_)
|