brainstate 0.1.9__py2.py3-none-any.whl → 0.2.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (163) hide show
  1. brainstate/__init__.py +130 -19
  2. brainstate/_compatible_import.py +201 -9
  3. brainstate/_compatible_import_test.py +681 -0
  4. brainstate/_deprecation.py +210 -0
  5. brainstate/_deprecation_test.py +2319 -0
  6. brainstate/{util/error.py → _error.py} +10 -20
  7. brainstate/_state.py +94 -47
  8. brainstate/_state_test.py +1 -1
  9. brainstate/_utils.py +1 -1
  10. brainstate/environ.py +1279 -347
  11. brainstate/environ_test.py +1187 -26
  12. brainstate/graph/__init__.py +6 -13
  13. brainstate/graph/_node.py +240 -0
  14. brainstate/graph/_node_test.py +589 -0
  15. brainstate/graph/{_graph_operation.py → _operation.py} +632 -746
  16. brainstate/graph/_operation_test.py +1147 -0
  17. brainstate/mixin.py +1209 -141
  18. brainstate/mixin_test.py +991 -51
  19. brainstate/nn/__init__.py +74 -72
  20. brainstate/nn/_activations.py +587 -295
  21. brainstate/nn/_activations_test.py +109 -86
  22. brainstate/nn/_collective_ops.py +393 -274
  23. brainstate/nn/_collective_ops_test.py +746 -15
  24. brainstate/nn/_common.py +114 -66
  25. brainstate/nn/_common_test.py +154 -0
  26. brainstate/nn/_conv.py +1652 -143
  27. brainstate/nn/_conv_test.py +838 -227
  28. brainstate/nn/_delay.py +95 -29
  29. brainstate/nn/_delay_test.py +25 -20
  30. brainstate/nn/_dropout.py +359 -167
  31. brainstate/nn/_dropout_test.py +429 -52
  32. brainstate/nn/_dynamics.py +14 -90
  33. brainstate/nn/_dynamics_test.py +1 -12
  34. brainstate/nn/_elementwise.py +492 -313
  35. brainstate/nn/_elementwise_test.py +806 -145
  36. brainstate/nn/_embedding.py +369 -19
  37. brainstate/nn/_embedding_test.py +156 -0
  38. brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +10 -16
  39. brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +6 -5
  40. brainstate/nn/{_linear_mv.py → _event_linear.py} +2 -2
  41. brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +6 -5
  42. brainstate/nn/_exp_euler.py +200 -38
  43. brainstate/nn/_exp_euler_test.py +350 -8
  44. brainstate/nn/_linear.py +391 -71
  45. brainstate/nn/_linear_test.py +427 -59
  46. brainstate/nn/_metrics.py +1070 -0
  47. brainstate/nn/_metrics_test.py +611 -0
  48. brainstate/nn/_module.py +10 -3
  49. brainstate/nn/_module_test.py +1 -1
  50. brainstate/nn/_normalizations.py +688 -329
  51. brainstate/nn/_normalizations_test.py +663 -37
  52. brainstate/nn/_paddings.py +1020 -0
  53. brainstate/nn/_paddings_test.py +723 -0
  54. brainstate/nn/_poolings.py +1404 -342
  55. brainstate/nn/_poolings_test.py +828 -92
  56. brainstate/nn/{_rate_rnns.py → _rnns.py} +446 -54
  57. brainstate/nn/_rnns_test.py +593 -0
  58. brainstate/nn/_utils.py +132 -5
  59. brainstate/nn/_utils_test.py +402 -0
  60. brainstate/{init/_random_inits.py → nn/init.py} +301 -45
  61. brainstate/{init/_random_inits_test.py → nn/init_test.py} +51 -20
  62. brainstate/random/__init__.py +247 -1
  63. brainstate/random/_rand_funs.py +668 -346
  64. brainstate/random/_rand_funs_test.py +74 -1
  65. brainstate/random/_rand_seed.py +541 -76
  66. brainstate/random/_rand_seed_test.py +1 -1
  67. brainstate/random/_rand_state.py +601 -393
  68. brainstate/random/_rand_state_test.py +551 -0
  69. brainstate/transform/__init__.py +59 -0
  70. brainstate/transform/_ad_checkpoint.py +176 -0
  71. brainstate/{compile → transform}/_ad_checkpoint_test.py +1 -1
  72. brainstate/{augment → transform}/_autograd.py +360 -113
  73. brainstate/{augment → transform}/_autograd_test.py +2 -2
  74. brainstate/transform/_conditions.py +316 -0
  75. brainstate/{compile → transform}/_conditions_test.py +11 -11
  76. brainstate/{compile → transform}/_error_if.py +22 -20
  77. brainstate/{compile → transform}/_error_if_test.py +1 -1
  78. brainstate/transform/_eval_shape.py +145 -0
  79. brainstate/{augment → transform}/_eval_shape_test.py +1 -1
  80. brainstate/{compile → transform}/_jit.py +99 -46
  81. brainstate/{compile → transform}/_jit_test.py +3 -3
  82. brainstate/{compile → transform}/_loop_collect_return.py +219 -80
  83. brainstate/{compile → transform}/_loop_collect_return_test.py +1 -1
  84. brainstate/{compile → transform}/_loop_no_collection.py +133 -34
  85. brainstate/{compile → transform}/_loop_no_collection_test.py +2 -2
  86. brainstate/transform/_make_jaxpr.py +2016 -0
  87. brainstate/transform/_make_jaxpr_test.py +1510 -0
  88. brainstate/transform/_mapping.py +529 -0
  89. brainstate/transform/_mapping_test.py +194 -0
  90. brainstate/{compile → transform}/_progress_bar.py +78 -25
  91. brainstate/{augment → transform}/_random.py +65 -45
  92. brainstate/{compile → transform}/_unvmap.py +102 -5
  93. brainstate/transform/_util.py +286 -0
  94. brainstate/typing.py +594 -61
  95. brainstate/typing_test.py +780 -0
  96. brainstate/util/__init__.py +9 -32
  97. brainstate/util/_others.py +1025 -0
  98. brainstate/util/_others_test.py +962 -0
  99. brainstate/util/_pretty_pytree.py +1301 -0
  100. brainstate/util/_pretty_pytree_test.py +675 -0
  101. brainstate/util/{pretty_repr.py → _pretty_repr.py} +161 -27
  102. brainstate/util/_pretty_repr_test.py +696 -0
  103. brainstate/util/filter.py +557 -81
  104. brainstate/util/filter_test.py +912 -0
  105. brainstate/util/struct.py +769 -382
  106. brainstate/util/struct_test.py +602 -0
  107. {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/METADATA +34 -17
  108. brainstate-0.2.0.dist-info/RECORD +111 -0
  109. brainstate/augment/__init__.py +0 -30
  110. brainstate/augment/_eval_shape.py +0 -99
  111. brainstate/augment/_mapping.py +0 -1060
  112. brainstate/augment/_mapping_test.py +0 -597
  113. brainstate/compile/__init__.py +0 -38
  114. brainstate/compile/_ad_checkpoint.py +0 -204
  115. brainstate/compile/_conditions.py +0 -256
  116. brainstate/compile/_make_jaxpr.py +0 -888
  117. brainstate/compile/_make_jaxpr_test.py +0 -156
  118. brainstate/compile/_util.py +0 -147
  119. brainstate/functional/__init__.py +0 -27
  120. brainstate/graph/_graph_node.py +0 -244
  121. brainstate/graph/_graph_node_test.py +0 -73
  122. brainstate/graph/_graph_operation_test.py +0 -563
  123. brainstate/init/__init__.py +0 -26
  124. brainstate/init/_base.py +0 -52
  125. brainstate/init/_generic.py +0 -244
  126. brainstate/init/_regular_inits.py +0 -105
  127. brainstate/init/_regular_inits_test.py +0 -50
  128. brainstate/nn/_inputs.py +0 -608
  129. brainstate/nn/_ltp.py +0 -28
  130. brainstate/nn/_neuron.py +0 -705
  131. brainstate/nn/_neuron_test.py +0 -161
  132. brainstate/nn/_others.py +0 -46
  133. brainstate/nn/_projection.py +0 -486
  134. brainstate/nn/_rate_rnns_test.py +0 -63
  135. brainstate/nn/_readout.py +0 -209
  136. brainstate/nn/_readout_test.py +0 -53
  137. brainstate/nn/_stp.py +0 -236
  138. brainstate/nn/_synapse.py +0 -505
  139. brainstate/nn/_synapse_test.py +0 -131
  140. brainstate/nn/_synaptic_projection.py +0 -423
  141. brainstate/nn/_synouts.py +0 -162
  142. brainstate/nn/_synouts_test.py +0 -57
  143. brainstate/nn/metrics.py +0 -388
  144. brainstate/optim/__init__.py +0 -38
  145. brainstate/optim/_base.py +0 -64
  146. brainstate/optim/_lr_scheduler.py +0 -448
  147. brainstate/optim/_lr_scheduler_test.py +0 -50
  148. brainstate/optim/_optax_optimizer.py +0 -152
  149. brainstate/optim/_optax_optimizer_test.py +0 -53
  150. brainstate/optim/_sgd_optimizer.py +0 -1104
  151. brainstate/random/_random_for_unit.py +0 -52
  152. brainstate/surrogate.py +0 -1957
  153. brainstate/transform.py +0 -23
  154. brainstate/util/caller.py +0 -98
  155. brainstate/util/others.py +0 -540
  156. brainstate/util/pretty_pytree.py +0 -945
  157. brainstate/util/pretty_pytree_test.py +0 -159
  158. brainstate/util/pretty_table.py +0 -2954
  159. brainstate/util/scaling.py +0 -258
  160. brainstate-0.1.9.dist-info/RECORD +0 -130
  161. {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/WHEEL +0 -0
  162. {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/licenses/LICENSE +0 -0
  163. {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,111 @@
1
+ brainstate/__init__.py,sha256=bmCIZG6xMiVlZ5qWk-d5qpkqdnrWxFhLePtpu3ff5WI,6024
2
+ brainstate/_compatible_import.py,sha256=7thV_2F0FD5AF2DETjBfmtNb_2ZQzki8NxFgC62frg0,11037
3
+ brainstate/_compatible_import_test.py,sha256=6ka26Sa_Kk6F-Ar1HR6UaKJTHquXcUCWglgXBUOovcg,22762
4
+ brainstate/_deprecation.py,sha256=gSh36_TWLBgQAo0gNfOzscV9ssa26k3te9y25BG6O2w,8381
5
+ brainstate/_deprecation_test.py,sha256=5_tJ9JDhG79zLwe5MuBLsqcgl7gi-oXPgDTCujrPmz0,88325
6
+ brainstate/_error.py,sha256=6A5ILy17ZMMZIjS8LkajTZBDRnwv_Qait5x__h2Levo,1522
7
+ brainstate/_state.py,sha256=YCYKrX2xCTojIY15vtQl8WWjtJ0HNMg8y5LCUk1d_3Q,58096
8
+ brainstate/_state_test.py,sha256=wCoWTbvARVkhNokMLLBV6sEoHSE9OJRyuUVrNDcRiG8,1589
9
+ brainstate/_utils.py,sha256=cmUyO9ds1etrrpV4ucp1G8mDqE15g4ZtbivblH_cD9o,1613
10
+ brainstate/environ.py,sha256=BmQsvo1aZaMpckHXlJ45dZh9DUdnHHP9Q9JdLNCa9wA,42169
11
+ brainstate/environ_test.py,sha256=RdVmeP7irbk3_qNjwWoy-DSdpkTRpxqAABtNGZcbB2w,42418
12
+ brainstate/mixin.py,sha256=O58rGznYowz-hBQy7iKOhbrg3Dze7FAR8AU4CX8hUEI,44635
13
+ brainstate/mixin_test.py,sha256=6WmqJf34kT3Z5WaiCNDo3OV3ci0DIHCQp474zECDxEU,34718
14
+ brainstate/typing.py,sha256=pYiNI-9oHpH7HfjRKYxugK03KGiamCwweagMyO0rsi4,26301
15
+ brainstate/typing_test.py,sha256=2mmMW0uAzIo3_VXpT5Boq79BohxYkzBlHexBysFUGII,26240
16
+ brainstate/graph/__init__.py,sha256=kGVtHAnkiWR5MqDYQU0G3AobWnioGeDqjILA--RyDz8,846
17
+ brainstate/graph/_node.py,sha256=_XH8xx6_glsCK4KCsQnarACK8meyhCdfh3nWfUDko0k,6407
18
+ brainstate/graph/_node_test.py,sha256=sD2DS0AhDKOU5ZQm0cYz0llnJ6D60ftNfDpztk4i8cM,18687
19
+ brainstate/graph/_operation.py,sha256=n2HqfwPzG3f6QHish_7d8lNZRhHUyS4_YHTdPMGgdUk,54096
20
+ brainstate/graph/_operation_test.py,sha256=IVyrJh4io3sDgtrTEIAItGzNs2XEf7rO1rvI1r_KiII,39119
21
+ brainstate/nn/__init__.py,sha256=oEQ81xpWppPMUMCRalMhAhUU5faYs0aKcyydwjoxGVo,4759
22
+ brainstate/nn/_activations.py,sha256=6jHR67obYR1lpo-imVXmfd3m_NDyU0XZb8t-pVYDvUU,26917
23
+ brainstate/nn/_activations_test.py,sha256=Ikr8RYBaIpApVKUhY-XAWr6llEG7vYWS8YuqDHyTtBY,13438
24
+ brainstate/nn/_collective_ops.py,sha256=UditkuCy0f9ggkFwGUXyJDoFtDdVUIQU3xmD5YOxRKg,21280
25
+ brainstate/nn/_collective_ops_test.py,sha256=8mKQkfTjfwuO7DA1i_Yr4QD1yg2ZJgVpX7zhHFo9CuQ,25600
26
+ brainstate/nn/_common.py,sha256=UyJMJoVF9KfrToOX5Dbv-2s3CD49SsroLRfL17DLl4Q,7184
27
+ brainstate/nn/_common_test.py,sha256=NpKQpKPPU4I2nVpz1X_bOGC5R-RyCP2kYlyYqXWoKAE,5966
28
+ brainstate/nn/_conv.py,sha256=3cGToc5UoGN5jp4BUHlrK6md8O_0IcWMCBXFHBQg7nE,82106
29
+ brainstate/nn/_conv_test.py,sha256=65FlrteUxLQb8ckUUJFhaPZDsQZTlGVAp0HUGUJtt1M,30173
30
+ brainstate/nn/_delay.py,sha256=0brSPlKOC7_nNWr1zcdjXNHJFw_I9KUX7PGyVWXENTk,22331
31
+ brainstate/nn/_delay_test.py,sha256=FzBb8vXfse8HEcEid83hpa6aag6oj90mtcHYsDs0DOE,10376
32
+ brainstate/nn/_dropout.py,sha256=UotjW0PQO4gypfhtSqzkR4UVGkY0kNBncEB_poGL7Sc,22555
33
+ brainstate/nn/_dropout_test.py,sha256=5LEN_G6t2TZAbmvJdZ75h1nyKZTHHiNT_i1W-im5wo8,21461
34
+ brainstate/nn/_dynamics.py,sha256=V7GA1enppLRRDJCw146RI_CdwpVyRgHUYH9fu2-PJOA,44181
35
+ brainstate/nn/_dynamics_test.py,sha256=G1VkrWgKK4CXtcKJmevxiQQ9DVeyf1zcdMsAWyCN0Jw,2356
36
+ brainstate/nn/_elementwise.py,sha256=4kKzrbKn5luwnpY8n7IeaMOtBVVie8oPHEc64hSn8-w,34858
37
+ brainstate/nn/_elementwise_test.py,sha256=sbWlUyTB8oiu3PRHObTvmUaob99EjIGYE6k6bEpg6K8,27296
38
+ brainstate/nn/_embedding.py,sha256=eMwbz9udm6WGVEGeHwhvCr7sJA6kMFyaZsli7lL30eg,14942
39
+ brainstate/nn/_embedding_test.py,sha256=Gc0y6gHMEagaDrBJoAYQZMDTd47TYQNVrenWUwLWK_w,6242
40
+ brainstate/nn/_event_fixedprob.py,sha256=ZEnIyjDksxtUWWG5GXLcF-RHR1S33DZTw-rN2lHKs0g,9395
41
+ brainstate/nn/_event_fixedprob_test.py,sha256=rvTKxEzKwvctQc8-AxXjJ4p4D-if1va2O8KqQYW5nxY,3836
42
+ brainstate/nn/_event_linear.py,sha256=d0J54Sf9zBl926BzXoy4Oc1p96h9veU3f8YhWrZLRPk,2554
43
+ brainstate/nn/_event_linear_test.py,sha256=qzcGplDIwxTnZOs4JzD5GX_oNBYtcYFNaf3GpWo8pZY,3765
44
+ brainstate/nn/_exp_euler.py,sha256=W5ofUPe6UK8NXO917PY9_nNXOI9chSHgBmnDHMVXHoo,8635
45
+ brainstate/nn/_exp_euler_test.py,sha256=21qomGOo96YLmlEQok2hCByAcRpmqREhSr3kxmhKOm8,13014
46
+ brainstate/nn/_linear.py,sha256=olbo9AmC35UQBxECyAQW06XeucsJfBAAMoSfwOW6c7s,24047
47
+ brainstate/nn/_linear_test.py,sha256=5fHx4v4_54dH3Bsyapl2cobvVtMEu3DFR-jDKLbkJFw,17876
48
+ brainstate/nn/_metrics.py,sha256=TgALwv6i9La4Dm1WAkWDWxvxr9rkd7CJLGFy2sOGQbQ,36481
49
+ brainstate/nn/_metrics_test.py,sha256=XZiRndchRgEH0X8zsHzg0fsHNMxj43mnd883QChfSik,24104
50
+ brainstate/nn/_module.py,sha256=to280ubWAP-HiCV7LknMBqOhet0UkH9Oh-PkLkual_g,12775
51
+ brainstate/nn/_module_test.py,sha256=znjB7FU5evJENQ1Pqw7ZlOGC5faQe4-4VpjW60H8UWI,1414
52
+ brainstate/nn/_normalizations.py,sha256=exdIn627ph5pVHdYx_4NIaK_f1xcLaBj-YQ0Ui22CsA,50185
53
+ brainstate/nn/_normalizations_test.py,sha256=y5n7aTaUHRkyAAjq4Oj8dfButJC3ehm4KdUB7226Bow,23350
54
+ brainstate/nn/_paddings.py,sha256=3u3dbRFtPSlIsLMBYZmHcYDQ6HFl0u_d-yP7ZoYcCrA,32415
55
+ brainstate/nn/_paddings_test.py,sha256=uY9CRexf9sM5V6AzTfzm8214y4IoGqEbahy92yaKsWM,27409
56
+ brainstate/nn/_poolings.py,sha256=aiLDTgtbDEse2OZlJONdU1CCsBjZtsWqsnL8ffTSGpk,86045
57
+ brainstate/nn/_poolings_test.py,sha256=cZ4lyutUY-Iti6rkGWjqWZ72fXF5QvwX9mki4CQ1xDs,34164
58
+ brainstate/nn/_rnns.py,sha256=FmuVsUwQCl7QbePmYokd-ZxKwVgDqa8FdCto_dNFDjk,32947
59
+ brainstate/nn/_rnns_test.py,sha256=gqPGz64i_QUEXi2fUBrIaBEPZjTs8hm8S6HyPbs3I6M,22112
60
+ brainstate/nn/_utils.py,sha256=VK-Se53e1q-Ip4AtMOZ3SUzYw8u2UllLJRLRtEFRCRE,7403
61
+ brainstate/nn/_utils_test.py,sha256=uim2SkfNHrBZzNDvN0WOK8qeZC1kaeOd-UQDvrn_M24,14266
62
+ brainstate/nn/init.py,sha256=7iLHrL-ZHpU-g5d0PlusaqmtkWO7X_KNr2eLx90oHrc,25656
63
+ brainstate/nn/init_test.py,sha256=bfby6kovvhbc7CCEaohtQawUQj3w3GvaMSDAdTiP2ps,6200
64
+ brainstate/random/__init__.py,sha256=2k5aI3GI_ftOtuWO9rosRUhyr5OjnMty82b-CTBkOUU,8383
65
+ brainstate/random/_rand_funs.py,sha256=hBkZYTlBNHBLdHTLBecT7sja7mVItMF7TV3ao26vfj0,135339
66
+ brainstate/random/_rand_funs_test.py,sha256=mZaSRtlXCwJP3YfF5GsFelR-YACg-uH1gHWG-coZT_k,23040
67
+ brainstate/random/_rand_seed.py,sha256=mzSe2lsyhN_0eXhkD8gpRKmSbReQfZF0pZA89cjiiHU,24927
68
+ brainstate/random/_rand_seed_test.py,sha256=Y2VCAkUzciDaCfYZWPe_Ewmi3MylK-WzfPA7TzorV8Q,1491
69
+ brainstate/random/_rand_state.py,sha256=_62wIrdZKKvmtjBErB0lRCuf1vYEskFyPd4Nhp_FKDU,53591
70
+ brainstate/random/_rand_state_test.py,sha256=0Y1kx7rvkACQiiiAHK0plePtLxYeYTzvWUGt3MimSGI,19227
71
+ brainstate/transform/__init__.py,sha256=CfzUYGQFt9hAp57ZvffUBBhLjudgzJ9n_aTaor57iOk,2241
72
+ brainstate/transform/_ad_checkpoint.py,sha256=4dcNCEQVV_CPMSkE32URERDMpQHbyfdGeLT_Nvhyd4o,6912
73
+ brainstate/transform/_ad_checkpoint_test.py,sha256=fPXBjDxsLHbL2mhIU3x_F5BpitkXLpgIsRCRgm2Us6w,1697
74
+ brainstate/transform/_autograd.py,sha256=4zGSYa9TMn6bqzPJNLfU9UZGZRyYxmwMXTKZWO4w3QQ,39991
75
+ brainstate/transform/_autograd_test.py,sha256=saWG1_k3cRXpsyQDzQkOLGvsF7IIxG9aGnjrf5B3HNk,44112
76
+ brainstate/transform/_conditions.py,sha256=IIu_V2f7R74saOoURGUxwzI1RpsnxkzpfB2fes9plOs,11399
77
+ brainstate/transform/_conditions_test.py,sha256=MEuqRq6IFmyORRDi0qWvNo4pWKFyc8aNrW1v9Saqxj0,8493
78
+ brainstate/transform/_error_if.py,sha256=e9tp3wT5p4bEyjn_Za_SrPNOG3OIoPBMIrvG2CsZzvw,2680
79
+ brainstate/transform/_error_if_test.py,sha256=yn-qcZ6lZUWciIif4fJOpdpKzJFAAdfgzZm6FfPeq7U,1848
80
+ brainstate/transform/_eval_shape.py,sha256=BNbjiFHUsk-qfENiZf1K8yX-x7eIIuAyWR0CjBIBr5w,5355
81
+ brainstate/transform/_eval_shape_test.py,sha256=4A2NdHcpksiGPf_UmPlMPgHJnes9ciiQO20ZJHVzA9g,1355
82
+ brainstate/transform/_jit.py,sha256=qYsL3Z9nAAW0UyQe_AyvBEuJvqAan6iw3lN49o1oC0A,15421
83
+ brainstate/transform/_jit_test.py,sha256=ecw54dGQYdJq2J94itPrXBQrSDNvCY_htUD7z7y4HUM,4013
84
+ brainstate/transform/_loop_collect_return.py,sha256=HhjC2gq6qzliw4ofP16VxdtR5hW-NmDZdeHxuiLdYGk,25899
85
+ brainstate/transform/_loop_collect_return_test.py,sha256=BVK-b3CuDtTXciRaA_8t4751N4taQOnIPNzAelSts-k,1753
86
+ brainstate/transform/_loop_no_collection.py,sha256=ArPpNemMh4jJsq_vUWPxuagCnxTlONN--3P_-44qYq8,10156
87
+ brainstate/transform/_loop_no_collection_test.py,sha256=3bRo9_Oaypbw3asEevrgTK0WksxDAIZKJgeaWpt7nl8,1371
88
+ brainstate/transform/_make_jaxpr.py,sha256=A9ivekPuzLUCoUcgSLm5hUuWPj5_1_ksqn2KvztcWhY,73327
89
+ brainstate/transform/_make_jaxpr_test.py,sha256=dFAnEQavYlEddex6h42ulRkXbV5TSGCDHFXWAvPsy7w,49128
90
+ brainstate/transform/_mapping.py,sha256=L9Q_1M_GXJA0ZCJtty9x01wViyefgeHyOLga-Jo2AE8,21570
91
+ brainstate/transform/_mapping_test.py,sha256=yGG-QSo3epqtBvGX8VxrNDKiC-d6TMN_jv_uKPad6eM,7647
92
+ brainstate/transform/_progress_bar.py,sha256=kZ-mI5hbUQXhqKFVyo0qeKG_LvrR9ZIar7WkXyOeET0,8961
93
+ brainstate/transform/_random.py,sha256=ZTH5Smx5SvFy6El7qk-ihoYE9WII6TJXcsC9Vm7VgjA,5259
94
+ brainstate/transform/_unvmap.py,sha256=cW6fjs5Iy1YBB6Nx2mxlM4IzV8U99bEX5QjT8rBRDho,6319
95
+ brainstate/transform/_util.py,sha256=IYqJj7oyAYzm_m3d9WEsUQRKDdVLQaAWpwM5O8PD4YQ,11304
96
+ brainstate/util/__init__.py,sha256=anHdG5BIsMqcBQy7gt8lErKInZv1wf2NOLJGUqltyAQ,1154
97
+ brainstate/util/_others.py,sha256=HSP3ynNv2ocxPz5omQl2rxMTOjGkQNNprsTOCEANCjA,30837
98
+ brainstate/util/_others_test.py,sha256=gEiUybMxtv11tsD0jge9Y0K8idkWkHVBlTZz1gKXNAU,30535
99
+ brainstate/util/_pretty_pytree.py,sha256=0fVwW8qHtKrpJU4q2tbn8usSKOlG8lBALn6E6crpf64,46815
100
+ brainstate/util/_pretty_pytree_test.py,sha256=6DiYX_Rwyvv3XcuQHsU-PuWXmvOqSmiVVN4diVvp76o,22362
101
+ brainstate/util/_pretty_repr.py,sha256=fafG6SIFoPjaWmQcTHwmnEqbVvcTCoN_lKiBKBk1QTQ,13958
102
+ brainstate/util/_pretty_repr_test.py,sha256=yfaANMfozlM5E3IZv-LBaNpyvSVB8ftRp0r7_oAr5yA,23122
103
+ brainstate/util/filter.py,sha256=wY_XUF3OhrXSV1bZkTcVhlEPba4HP1l9N5aRW2zgxqQ,27410
104
+ brainstate/util/filter_test.py,sha256=ZfrEeOc1yMHYzrcSR3p4jbZGj7c_tXC8VcPq7H13q8E,31653
105
+ brainstate/util/struct.py,sha256=LYPLGDGfPuw14hhx5k4rb8msSH3yZPdAu_0CvjxPWwE,24505
106
+ brainstate/util/struct_test.py,sha256=q_fWsUH1ON35DKjUUAMq6VtYglqTDyBvd6WMVGD89EI,16526
107
+ brainstate-0.2.0.dist-info/licenses/LICENSE,sha256=RJ40fox7u2in2H8wvIS5DsPGlNHaA7JI024thFUlaZE,11348
108
+ brainstate-0.2.0.dist-info/METADATA,sha256=GrI-RT31978SYlGNnRP9xodMOYl1g6mPpIiP1XdUaOY,4421
109
+ brainstate-0.2.0.dist-info/WHEEL,sha256=JNWh1Fm1UdwIQV075glCn4MVuCRs0sotJIq-J6rbxCU,109
110
+ brainstate-0.2.0.dist-info/top_level.txt,sha256=eQbGgKn0ptx7FDWuua0V0wr4K1VHi2iOUCYo3fUQBRA,11
111
+ brainstate-0.2.0.dist-info/RECORD,,
@@ -1,30 +0,0 @@
1
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
- """
17
- This module includes transformations for augmenting the functionalities of JAX code.
18
- """
19
-
20
- from ._autograd import GradientTransform, grad, vector_grad, hessian, jacobian, jacrev, jacfwd
21
- from ._eval_shape import abstract_init
22
- from ._mapping import vmap, pmap, map, vmap_new_states
23
- from ._random import restore_rngs
24
-
25
- __all__ = [
26
- 'GradientTransform', 'grad', 'vector_grad', 'hessian', 'jacobian', 'jacrev', 'jacfwd',
27
- 'abstract_init',
28
- 'vmap', 'pmap', 'map', 'vmap_new_states',
29
- 'restore_rngs',
30
- ]
@@ -1,99 +0,0 @@
1
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
- import functools
17
- from typing import Any, TypeVar, Callable, Sequence, Union
18
-
19
- import jax
20
-
21
- from brainstate import random
22
- from brainstate.graph import Node, flatten, unflatten
23
- from ._random import restore_rngs
24
-
25
- __all__ = [
26
- 'abstract_init',
27
- ]
28
-
29
- A = TypeVar('A')
30
-
31
-
32
- def abstract_init(
33
- fn: Callable[..., A],
34
- *args: Any,
35
- rngs: Union[random.RandomState, Sequence[random.RandomState]] = random.DEFAULT,
36
- **kwargs: Any,
37
- ) -> A:
38
- """
39
- Compute the shape/dtype of ``fn`` without any FLOPs.
40
-
41
- Here's an example::
42
-
43
- >>> import brainstate
44
- >>> class MLP:
45
- ... def __init__(self, n_in, n_mid, n_out):
46
- ... self.dense1 = brainstate.nn.Linear(n_in, n_mid)
47
- ... self.dense2 = brainstate.nn.Linear(n_mid, n_out)
48
-
49
- >>> r = brainstate.augment.abstract_init(lambda: MLP(1, 2, 3))
50
- >>> r
51
- MLP(
52
- dense1=Linear(
53
- in_size=(1,),
54
- out_size=(2,),
55
- w_mask=None,
56
- weight=ParamState(
57
- value={'bias': ShapeDtypeStruct(shape=(2,), dtype=float32), 'weight': ShapeDtypeStruct(shape=(1, 2), dtype=float32)}
58
- )
59
- ),
60
- dense2=Linear(
61
- in_size=(2,),
62
- out_size=(3,),
63
- w_mask=None,
64
- weight=ParamState(
65
- value={'bias': ShapeDtypeStruct(shape=(3,), dtype=float32), 'weight': ShapeDtypeStruct(shape=(2, 3), dtype=float32)}
66
- )
67
- )
68
- )
69
-
70
- Args:
71
- fn: The function whose output shape should be evaluated.
72
- *args: a positional argument tuple of arrays, scalars, or (nested) standard
73
- Python containers (tuples, lists, dicts, namedtuples, i.e. pytrees) of
74
- those types. Since only the ``shape`` and ``dtype`` attributes are
75
- accessed, one can use :class:`jax.ShapeDtypeStruct` or another container
76
- that duck-types as ndarrays (note however that duck-typed objects cannot
77
- be namedtuples because those are treated as standard Python containers).
78
- **kwargs: a keyword argument dict of arrays, scalars, or (nested) standard
79
- Python containers (pytrees) of those types. As in ``args``, array values
80
- need only be duck-typed to have ``shape`` and ``dtype`` attributes.
81
- rngs: a :class:`RandomState` or a sequence of :class:`RandomState` objects
82
- representing the random number generators to use. If not provided, the
83
- default random number generator will be used.
84
-
85
- Returns:
86
- out: a nested PyTree containing :class:`jax.ShapeDtypeStruct` objects as leaves.
87
-
88
- """
89
-
90
- @functools.wraps(fn)
91
- @restore_rngs(rngs=rngs)
92
- def _eval_shape_fn(*args_, **kwargs_):
93
- out = fn(*args_, **kwargs_)
94
- assert isinstance(out, Node), 'The output of the function must be Node'
95
- graph_def, treefy_states = flatten(out)
96
- return graph_def, treefy_states
97
-
98
- graph_def_, treefy_states_ = jax.eval_shape(_eval_shape_fn, *args, **kwargs)
99
- return unflatten(graph_def_, treefy_states_)