brainstate 0.1.10__tar.gz → 0.2.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (201)
  1. {brainstate-0.1.10 → brainstate-0.2.0}/PKG-INFO +34 -17
  2. {brainstate-0.1.10 → brainstate-0.2.0}/README.md +3 -3
  3. brainstate-0.2.0/brainstate/__init__.py +169 -0
  4. brainstate-0.2.0/brainstate/_compatible_import.py +340 -0
  5. brainstate-0.2.0/brainstate/_compatible_import_test.py +681 -0
  6. brainstate-0.2.0/brainstate/_deprecation.py +210 -0
  7. brainstate-0.2.0/brainstate/_deprecation_test.py +2319 -0
  8. brainstate-0.1.10/brainstate/util/error.py → brainstate-0.2.0/brainstate/_error.py +10 -20
  9. {brainstate-0.1.10 → brainstate-0.2.0}/brainstate/_state.py +94 -47
  10. {brainstate-0.1.10 → brainstate-0.2.0}/brainstate/_state_test.py +1 -1
  11. {brainstate-0.1.10 → brainstate-0.2.0}/brainstate/_utils.py +1 -1
  12. brainstate-0.2.0/brainstate/environ.py +1495 -0
  13. brainstate-0.2.0/brainstate/environ_test.py +1223 -0
  14. brainstate-0.1.10/brainstate/transform.py → brainstate-0.2.0/brainstate/graph/__init__.py +6 -7
  15. brainstate-0.2.0/brainstate/graph/_node.py +240 -0
  16. brainstate-0.2.0/brainstate/graph/_node_test.py +589 -0
  17. brainstate-0.1.10/brainstate/graph/_graph_operation.py → brainstate-0.2.0/brainstate/graph/_operation.py +632 -746
  18. brainstate-0.2.0/brainstate/graph/_operation_test.py +1147 -0
  19. brainstate-0.2.0/brainstate/mixin.py +1433 -0
  20. brainstate-0.2.0/brainstate/mixin_test.py +1017 -0
  21. brainstate-0.2.0/brainstate/nn/__init__.py +137 -0
  22. brainstate-0.2.0/brainstate/nn/_activations.py +1100 -0
  23. {brainstate-0.1.10 → brainstate-0.2.0}/brainstate/nn/_activations_test.py +109 -86
  24. brainstate-0.2.0/brainstate/nn/_collective_ops.py +633 -0
  25. brainstate-0.2.0/brainstate/nn/_collective_ops_test.py +774 -0
  26. brainstate-0.2.0/brainstate/nn/_common.py +226 -0
  27. brainstate-0.2.0/brainstate/nn/_common_test.py +154 -0
  28. brainstate-0.2.0/brainstate/nn/_conv.py +2010 -0
  29. brainstate-0.2.0/brainstate/nn/_conv_test.py +849 -0
  30. {brainstate-0.1.10 → brainstate-0.2.0}/brainstate/nn/_delay.py +15 -28
  31. {brainstate-0.1.10 → brainstate-0.2.0}/brainstate/nn/_delay_test.py +25 -20
  32. brainstate-0.2.0/brainstate/nn/_dropout.py +618 -0
  33. brainstate-0.2.0/brainstate/nn/_dropout_test.py +477 -0
  34. {brainstate-0.1.10 → brainstate-0.2.0}/brainstate/nn/_dynamics.py +14 -90
  35. {brainstate-0.1.10 → brainstate-0.2.0}/brainstate/nn/_dynamics_test.py +1 -12
  36. {brainstate-0.1.10 → brainstate-0.2.0}/brainstate/nn/_elementwise.py +492 -313
  37. brainstate-0.2.0/brainstate/nn/_elementwise_test.py +830 -0
  38. brainstate-0.2.0/brainstate/nn/_embedding.py +408 -0
  39. brainstate-0.2.0/brainstate/nn/_embedding_test.py +156 -0
  40. brainstate-0.1.10/brainstate/nn/_fixedprob.py → brainstate-0.2.0/brainstate/nn/_event_fixedprob.py +10 -16
  41. brainstate-0.1.10/brainstate/nn/_fixedprob_test.py → brainstate-0.2.0/brainstate/nn/_event_fixedprob_test.py +6 -5
  42. brainstate-0.1.10/brainstate/nn/_linear_mv.py → brainstate-0.2.0/brainstate/nn/_event_linear.py +2 -2
  43. brainstate-0.1.10/brainstate/nn/_linear_mv_test.py → brainstate-0.2.0/brainstate/nn/_event_linear_test.py +6 -5
  44. brainstate-0.2.0/brainstate/nn/_exp_euler.py +254 -0
  45. brainstate-0.2.0/brainstate/nn/_exp_euler_test.py +377 -0
  46. brainstate-0.2.0/brainstate/nn/_linear.py +744 -0
  47. brainstate-0.2.0/brainstate/nn/_linear_test.py +475 -0
  48. brainstate-0.2.0/brainstate/nn/_metrics.py +1070 -0
  49. brainstate-0.2.0/brainstate/nn/_metrics_test.py +611 -0
  50. {brainstate-0.1.10 → brainstate-0.2.0}/brainstate/nn/_module.py +10 -3
  51. {brainstate-0.1.10 → brainstate-0.2.0}/brainstate/nn/_module_test.py +1 -1
  52. brainstate-0.2.0/brainstate/nn/_normalizations.py +1334 -0
  53. brainstate-0.2.0/brainstate/nn/_normalizations_test.py +699 -0
  54. brainstate-0.2.0/brainstate/nn/_paddings.py +1020 -0
  55. brainstate-0.2.0/brainstate/nn/_paddings_test.py +723 -0
  56. brainstate-0.2.0/brainstate/nn/_poolings.py +2239 -0
  57. brainstate-0.2.0/brainstate/nn/_poolings_test.py +953 -0
  58. brainstate-0.1.10/brainstate/nn/_rate_rnns.py → brainstate-0.2.0/brainstate/nn/_rnns.py +446 -54
  59. brainstate-0.2.0/brainstate/nn/_rnns_test.py +593 -0
  60. brainstate-0.2.0/brainstate/nn/_utils.py +216 -0
  61. brainstate-0.2.0/brainstate/nn/_utils_test.py +402 -0
  62. brainstate-0.1.10/brainstate/init/_random_inits.py → brainstate-0.2.0/brainstate/nn/init.py +301 -45
  63. brainstate-0.1.10/brainstate/init/_random_inits_test.py → brainstate-0.2.0/brainstate/nn/init_test.py +51 -20
  64. brainstate-0.2.0/brainstate/random/__init__.py +270 -0
  65. {brainstate-0.1.10 → brainstate-0.2.0}/brainstate/random/_rand_funs.py +668 -346
  66. {brainstate-0.1.10 → brainstate-0.2.0}/brainstate/random/_rand_funs_test.py +74 -1
  67. brainstate-0.2.0/brainstate/random/_rand_seed.py +675 -0
  68. {brainstate-0.1.10 → brainstate-0.2.0}/brainstate/random/_rand_seed_test.py +1 -1
  69. {brainstate-0.1.10 → brainstate-0.2.0}/brainstate/random/_rand_state.py +601 -393
  70. brainstate-0.2.0/brainstate/random/_rand_state_test.py +551 -0
  71. brainstate-0.2.0/brainstate/transform/__init__.py +59 -0
  72. brainstate-0.2.0/brainstate/transform/_ad_checkpoint.py +176 -0
  73. {brainstate-0.1.10/brainstate/compile → brainstate-0.2.0/brainstate/transform}/_ad_checkpoint_test.py +1 -1
  74. {brainstate-0.1.10/brainstate/augment → brainstate-0.2.0/brainstate/transform}/_autograd.py +360 -113
  75. {brainstate-0.1.10/brainstate/augment → brainstate-0.2.0/brainstate/transform}/_autograd_test.py +2 -2
  76. brainstate-0.2.0/brainstate/transform/_conditions.py +316 -0
  77. {brainstate-0.1.10/brainstate/compile → brainstate-0.2.0/brainstate/transform}/_conditions_test.py +11 -11
  78. {brainstate-0.1.10/brainstate/compile → brainstate-0.2.0/brainstate/transform}/_error_if.py +22 -20
  79. {brainstate-0.1.10/brainstate/compile → brainstate-0.2.0/brainstate/transform}/_error_if_test.py +1 -1
  80. brainstate-0.2.0/brainstate/transform/_eval_shape.py +145 -0
  81. {brainstate-0.1.10/brainstate/augment → brainstate-0.2.0/brainstate/transform}/_eval_shape_test.py +1 -1
  82. {brainstate-0.1.10/brainstate/compile → brainstate-0.2.0/brainstate/transform}/_jit.py +99 -46
  83. {brainstate-0.1.10/brainstate/compile → brainstate-0.2.0/brainstate/transform}/_jit_test.py +3 -3
  84. {brainstate-0.1.10/brainstate/compile → brainstate-0.2.0/brainstate/transform}/_loop_collect_return.py +219 -80
  85. {brainstate-0.1.10/brainstate/compile → brainstate-0.2.0/brainstate/transform}/_loop_collect_return_test.py +1 -1
  86. {brainstate-0.1.10/brainstate/compile → brainstate-0.2.0/brainstate/transform}/_loop_no_collection.py +133 -34
  87. {brainstate-0.1.10/brainstate/compile → brainstate-0.2.0/brainstate/transform}/_loop_no_collection_test.py +2 -2
  88. brainstate-0.2.0/brainstate/transform/_make_jaxpr.py +2016 -0
  89. brainstate-0.2.0/brainstate/transform/_make_jaxpr_test.py +1510 -0
  90. brainstate-0.2.0/brainstate/transform/_mapping.py +529 -0
  91. brainstate-0.2.0/brainstate/transform/_mapping_test.py +194 -0
  92. {brainstate-0.1.10/brainstate/compile → brainstate-0.2.0/brainstate/transform}/_progress_bar.py +78 -25
  93. {brainstate-0.1.10/brainstate/augment → brainstate-0.2.0/brainstate/transform}/_random.py +65 -45
  94. {brainstate-0.1.10/brainstate/compile → brainstate-0.2.0/brainstate/transform}/_unvmap.py +102 -5
  95. brainstate-0.2.0/brainstate/transform/_util.py +286 -0
  96. brainstate-0.2.0/brainstate/typing.py +837 -0
  97. brainstate-0.2.0/brainstate/typing_test.py +780 -0
  98. {brainstate-0.1.10/brainstate/random → brainstate-0.2.0/brainstate/util}/__init__.py +12 -9
  99. brainstate-0.2.0/brainstate/util/_others.py +1025 -0
  100. brainstate-0.2.0/brainstate/util/_others_test.py +962 -0
  101. brainstate-0.2.0/brainstate/util/_pretty_pytree.py +1301 -0
  102. brainstate-0.2.0/brainstate/util/_pretty_pytree_test.py +675 -0
  103. brainstate-0.1.10/brainstate/util/pretty_repr.py → brainstate-0.2.0/brainstate/util/_pretty_repr.py +161 -27
  104. brainstate-0.2.0/brainstate/util/_pretty_repr_test.py +696 -0
  105. brainstate-0.2.0/brainstate/util/filter.py +945 -0
  106. brainstate-0.2.0/brainstate/util/filter_test.py +912 -0
  107. brainstate-0.2.0/brainstate/util/struct.py +910 -0
  108. brainstate-0.2.0/brainstate/util/struct_test.py +602 -0
  109. {brainstate-0.1.10 → brainstate-0.2.0}/brainstate.egg-info/PKG-INFO +34 -17
  110. brainstate-0.2.0/brainstate.egg-info/SOURCES.txt +114 -0
  111. brainstate-0.2.0/brainstate.egg-info/requires.txt +31 -0
  112. {brainstate-0.1.10 → brainstate-0.2.0}/pyproject.toml +50 -14
  113. brainstate-0.1.10/brainstate/__init__.py +0 -58
  114. brainstate-0.1.10/brainstate/_compatible_import.py +0 -148
  115. brainstate-0.1.10/brainstate/augment/__init__.py +0 -30
  116. brainstate-0.1.10/brainstate/augment/_eval_shape.py +0 -99
  117. brainstate-0.1.10/brainstate/augment/_mapping.py +0 -1060
  118. brainstate-0.1.10/brainstate/augment/_mapping_test.py +0 -597
  119. brainstate-0.1.10/brainstate/compile/__init__.py +0 -38
  120. brainstate-0.1.10/brainstate/compile/_ad_checkpoint.py +0 -204
  121. brainstate-0.1.10/brainstate/compile/_conditions.py +0 -256
  122. brainstate-0.1.10/brainstate/compile/_make_jaxpr.py +0 -888
  123. brainstate-0.1.10/brainstate/compile/_make_jaxpr_test.py +0 -156
  124. brainstate-0.1.10/brainstate/compile/_util.py +0 -147
  125. brainstate-0.1.10/brainstate/environ.py +0 -563
  126. brainstate-0.1.10/brainstate/environ_test.py +0 -62
  127. brainstate-0.1.10/brainstate/functional/__init__.py +0 -27
  128. brainstate-0.1.10/brainstate/graph/__init__.py +0 -29
  129. brainstate-0.1.10/brainstate/graph/_graph_node.py +0 -244
  130. brainstate-0.1.10/brainstate/graph/_graph_node_test.py +0 -73
  131. brainstate-0.1.10/brainstate/graph/_graph_operation_test.py +0 -563
  132. brainstate-0.1.10/brainstate/init/__init__.py +0 -26
  133. brainstate-0.1.10/brainstate/init/_base.py +0 -52
  134. brainstate-0.1.10/brainstate/init/_generic.py +0 -244
  135. brainstate-0.1.10/brainstate/init/_regular_inits.py +0 -105
  136. brainstate-0.1.10/brainstate/init/_regular_inits_test.py +0 -50
  137. brainstate-0.1.10/brainstate/mixin.py +0 -365
  138. brainstate-0.1.10/brainstate/mixin_test.py +0 -77
  139. brainstate-0.1.10/brainstate/nn/__init__.py +0 -135
  140. brainstate-0.1.10/brainstate/nn/_activations.py +0 -808
  141. brainstate-0.1.10/brainstate/nn/_collective_ops.py +0 -514
  142. brainstate-0.1.10/brainstate/nn/_collective_ops_test.py +0 -43
  143. brainstate-0.1.10/brainstate/nn/_common.py +0 -178
  144. brainstate-0.1.10/brainstate/nn/_conv.py +0 -501
  145. brainstate-0.1.10/brainstate/nn/_conv_test.py +0 -238
  146. brainstate-0.1.10/brainstate/nn/_dropout.py +0 -426
  147. brainstate-0.1.10/brainstate/nn/_dropout_test.py +0 -100
  148. brainstate-0.1.10/brainstate/nn/_elementwise_test.py +0 -169
  149. brainstate-0.1.10/brainstate/nn/_embedding.py +0 -58
  150. brainstate-0.1.10/brainstate/nn/_exp_euler.py +0 -92
  151. brainstate-0.1.10/brainstate/nn/_exp_euler_test.py +0 -35
  152. brainstate-0.1.10/brainstate/nn/_inputs.py +0 -608
  153. brainstate-0.1.10/brainstate/nn/_linear.py +0 -424
  154. brainstate-0.1.10/brainstate/nn/_linear_test.py +0 -107
  155. brainstate-0.1.10/brainstate/nn/_ltp.py +0 -28
  156. brainstate-0.1.10/brainstate/nn/_neuron.py +0 -705
  157. brainstate-0.1.10/brainstate/nn/_neuron_test.py +0 -161
  158. brainstate-0.1.10/brainstate/nn/_normalizations.py +0 -975
  159. brainstate-0.1.10/brainstate/nn/_normalizations_test.py +0 -73
  160. brainstate-0.1.10/brainstate/nn/_others.py +0 -46
  161. brainstate-0.1.10/brainstate/nn/_poolings.py +0 -1177
  162. brainstate-0.1.10/brainstate/nn/_poolings_test.py +0 -217
  163. brainstate-0.1.10/brainstate/nn/_projection.py +0 -486
  164. brainstate-0.1.10/brainstate/nn/_rate_rnns_test.py +0 -63
  165. brainstate-0.1.10/brainstate/nn/_readout.py +0 -209
  166. brainstate-0.1.10/brainstate/nn/_readout_test.py +0 -53
  167. brainstate-0.1.10/brainstate/nn/_stp.py +0 -236
  168. brainstate-0.1.10/brainstate/nn/_synapse.py +0 -505
  169. brainstate-0.1.10/brainstate/nn/_synapse_test.py +0 -131
  170. brainstate-0.1.10/brainstate/nn/_synaptic_projection.py +0 -423
  171. brainstate-0.1.10/brainstate/nn/_synouts.py +0 -162
  172. brainstate-0.1.10/brainstate/nn/_synouts_test.py +0 -57
  173. brainstate-0.1.10/brainstate/nn/_utils.py +0 -89
  174. brainstate-0.1.10/brainstate/nn/metrics.py +0 -388
  175. brainstate-0.1.10/brainstate/optim/__init__.py +0 -38
  176. brainstate-0.1.10/brainstate/optim/_base.py +0 -64
  177. brainstate-0.1.10/brainstate/optim/_lr_scheduler.py +0 -448
  178. brainstate-0.1.10/brainstate/optim/_lr_scheduler_test.py +0 -50
  179. brainstate-0.1.10/brainstate/optim/_optax_optimizer.py +0 -152
  180. brainstate-0.1.10/brainstate/optim/_optax_optimizer_test.py +0 -53
  181. brainstate-0.1.10/brainstate/optim/_sgd_optimizer.py +0 -1104
  182. brainstate-0.1.10/brainstate/random/_rand_seed.py +0 -210
  183. brainstate-0.1.10/brainstate/random/_random_for_unit.py +0 -52
  184. brainstate-0.1.10/brainstate/surrogate.py +0 -1957
  185. brainstate-0.1.10/brainstate/typing.py +0 -304
  186. brainstate-0.1.10/brainstate/util/__init__.py +0 -50
  187. brainstate-0.1.10/brainstate/util/caller.py +0 -98
  188. brainstate-0.1.10/brainstate/util/filter.py +0 -469
  189. brainstate-0.1.10/brainstate/util/others.py +0 -540
  190. brainstate-0.1.10/brainstate/util/pretty_pytree.py +0 -945
  191. brainstate-0.1.10/brainstate/util/pretty_pytree_test.py +0 -159
  192. brainstate-0.1.10/brainstate/util/pretty_table.py +0 -2954
  193. brainstate-0.1.10/brainstate/util/scaling.py +0 -258
  194. brainstate-0.1.10/brainstate/util/struct.py +0 -523
  195. brainstate-0.1.10/brainstate.egg-info/SOURCES.txt +0 -134
  196. brainstate-0.1.10/brainstate.egg-info/requires.txt +0 -8
  197. brainstate-0.1.10/setup.py +0 -96
  198. {brainstate-0.1.10 → brainstate-0.2.0}/LICENSE +0 -0
  199. {brainstate-0.1.10 → brainstate-0.2.0}/brainstate.egg-info/dependency_links.txt +0 -0
  200. {brainstate-0.1.10 → brainstate-0.2.0}/brainstate.egg-info/top_level.txt +0 -0
  201. {brainstate-0.1.10 → brainstate-0.2.0}/setup.cfg +0 -0
@@ -1,14 +1,15 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: brainstate
3
- Version: 0.1.10
4
- Summary: A ``State``-based Transformation System for Program Compilation and Augmentation.
5
- Home-page: https://github.com/chaobrain/brainstate
6
- Author: BrainState Developers
3
+ Version: 0.2.0
4
+ Summary: A State-based Transformation System for Brain Modeling.
7
5
  Author-email: BrainState Developers <chao.brain@qq.com>
8
6
  License: Apache-2.0 license
9
- Project-URL: homepage, http://github.com/chaobrain/brainstate
10
- Project-URL: repository, http://github.com/chaobrain/brainstate
11
- Keywords: computational neuroscience,brain-inspired computation,brain dynamics programming
7
+ Project-URL: homepage, https://github.com/chaobrain/brainstate
8
+ Project-URL: repository, https://github.com/chaobrain/brainstate
9
+ Project-URL: Documentation, https://brainstate.readthedocs.io/
10
+ Project-URL: Source Code, https://github.com/chaobrain/brainstate
11
+ Project-URL: Bug Tracker, https://github.com/chaobrain/brainstate/issues
12
+ Keywords: computational neuroscience,brain-inspired computing,brain simulation,brain modeling,spiking neural networks
12
13
  Classifier: Natural Language :: English
13
14
  Classifier: Operating System :: OS Independent
14
15
  Classifier: Development Status :: 4 - Beta
@@ -28,20 +29,36 @@ Classifier: Topic :: Software Development :: Libraries
28
29
  Requires-Python: >=3.10
29
30
  Description-Content-Type: text/markdown
30
31
  License-File: LICENSE
31
- Requires-Dist: jax
32
- Requires-Dist: jaxlib
33
- Requires-Dist: numpy
34
- Requires-Dist: brainunit>=0.1.0
32
+ Requires-Dist: numpy>=1.15
33
+ Requires-Dist: tqdm
34
+ Requires-Dist: brainunit
35
35
  Requires-Dist: brainevent
36
+ Provides-Extra: cpu
37
+ Requires-Dist: jax[cpu]; extra == "cpu"
38
+ Requires-Dist: brainunit; extra == "cpu"
39
+ Requires-Dist: brainevent; extra == "cpu"
40
+ Provides-Extra: cuda12
41
+ Requires-Dist: jax[cuda12]; extra == "cuda12"
42
+ Requires-Dist: brainunit; extra == "cuda12"
43
+ Requires-Dist: brainevent; extra == "cuda12"
44
+ Provides-Extra: cuda13
45
+ Requires-Dist: jax[cuda13]; extra == "cuda13"
46
+ Requires-Dist: brainunit; extra == "cuda13"
47
+ Requires-Dist: brainevent; extra == "cuda13"
48
+ Provides-Extra: tpu
49
+ Requires-Dist: jax[tpu]; extra == "tpu"
50
+ Requires-Dist: brainunit; extra == "tpu"
51
+ Requires-Dist: brainevent; extra == "tpu"
36
52
  Provides-Extra: testing
53
+ Requires-Dist: absl-py; extra == "testing"
37
54
  Requires-Dist: pytest; extra == "testing"
38
- Dynamic: author
39
- Dynamic: home-page
55
+ Requires-Dist: jax; extra == "testing"
56
+ Requires-Dist: brainunit; extra == "testing"
57
+ Requires-Dist: brainevent; extra == "testing"
40
58
  Dynamic: license-file
41
- Dynamic: requires-python
42
59
 
43
60
 
44
- # A ``State``-based Transformation System for Program Compilation and Augmentation
61
+ # A ``State``-based Transformation System for Brain Modeling
45
62
 
46
63
 
47
64
 
@@ -84,8 +101,8 @@ The official documentation is hosted on Read the Docs: [https://brainstate.readt
84
101
 
85
102
 
86
103
 
87
- ## See also the brain modeling ecosystem
104
+ ## See also the ecosystem
88
105
 
89
- We are building the brain modeling ecosystem: https://brainmodeling.readthedocs.io/
106
+ ``brainstate`` is one part of our brain simulation ecosystem: https://brainmodeling.readthedocs.io/
90
107
 
91
108
 
@@ -1,5 +1,5 @@
1
1
 
2
- # A ``State``-based Transformation System for Program Compilation and Augmentation
2
+ # A ``State``-based Transformation System for Brain Modeling
3
3
 
4
4
 
5
5
 
@@ -42,8 +42,8 @@ The official documentation is hosted on Read the Docs: [https://brainstate.readt
42
42
 
43
43
 
44
44
 
45
- ## See also the brain modeling ecosystem
45
+ ## See also the ecosystem
46
46
 
47
- We are building the brain modeling ecosystem: https://brainmodeling.readthedocs.io/
47
+ ``brainstate`` is one part of our brain simulation ecosystem: https://brainmodeling.readthedocs.io/
48
48
 
49
49
 
@@ -0,0 +1,169 @@
1
+ # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """
17
+ A ``State``-based Transformation System for Program Compilation and Augmentation
18
+ """
19
+
20
+ __version__ = "0.2.0"
21
+ __versio_info__ = (0, 2, 0)
22
+
23
+ from . import environ
24
+ from . import graph
25
+ from . import mixin
26
+ from . import nn
27
+ from . import random
28
+ from . import transform
29
+ from . import typing
30
+ from . import util
31
+ from ._error import *
32
+ from ._error import __all__ as _error_all
33
+ from ._state import *
34
+ from ._state import __all__ as _state_all
35
+
36
+ # Create deprecated module proxies with scoped APIs
37
+ from ._deprecation import create_deprecated_module_proxy
38
+
39
+ # Augment module scope
40
+ _augment_apis = {
41
+ 'GradientTransform': 'brainstate.transform._autograd',
42
+ 'grad': 'brainstate.transform._autograd',
43
+ 'vector_grad': 'brainstate.transform._autograd',
44
+ 'hessian': 'brainstate.transform._autograd',
45
+ 'jacobian': 'brainstate.transform._autograd',
46
+ 'jacrev': 'brainstate.transform._autograd',
47
+ 'jacfwd': 'brainstate.transform._autograd',
48
+ 'abstract_init': 'brainstate.transform._eval_shape',
49
+ 'vmap': 'brainstate.transform._mapping',
50
+ 'pmap': 'brainstate.transform._mapping',
51
+ 'map': 'brainstate.transform._mapping',
52
+ 'vmap_new_states': 'brainstate.transform._mapping',
53
+ 'restore_rngs': 'brainstate.transform._random',
54
+ }
55
+
56
+ augment = create_deprecated_module_proxy(
57
+ deprecated_name='brainstate.augment',
58
+ replacement_module=transform,
59
+ replacement_name='brainstate.transform',
60
+ scoped_apis=_augment_apis
61
+ )
62
+
63
+ # Compile module scope
64
+ _compile_apis = {
65
+ 'checkpoint': 'brainstate.transform._ad_checkpoint',
66
+ 'remat': 'brainstate.transform._ad_checkpoint',
67
+ 'cond': 'brainstate.transform._conditions',
68
+ 'switch': 'brainstate.transform._conditions',
69
+ 'ifelse': 'brainstate.transform._conditions',
70
+ 'jit_error_if': 'brainstate.transform._error_if',
71
+ 'jit': 'brainstate.transform._jit',
72
+ 'scan': 'brainstate.transform._loop_collect_return',
73
+ 'checkpointed_scan': 'brainstate.transform._loop_collect_return',
74
+ 'for_loop': 'brainstate.transform._loop_collect_return',
75
+ 'checkpointed_for_loop': 'brainstate.transform._loop_collect_return',
76
+ 'while_loop': 'brainstate.transform._loop_no_collection',
77
+ 'bounded_while_loop': 'brainstate.transform._loop_no_collection',
78
+ 'StatefulFunction': 'brainstate.transform._make_jaxpr',
79
+ 'make_jaxpr': 'brainstate.transform._make_jaxpr',
80
+ 'ProgressBar': 'brainstate.transform._progress_bar',
81
+ }
82
+
83
+ compile = create_deprecated_module_proxy(
84
+ deprecated_name='brainstate.compile',
85
+ replacement_module=transform,
86
+ replacement_name='brainstate.transform',
87
+ scoped_apis=_compile_apis
88
+ )
89
+
90
+ # Functional module scope - use direct attribute access from nn module
91
+ _functional_apis = {
92
+ 'weight_standardization': 'brainstate.nn._normalizations',
93
+ 'clip_grad_norm': 'brainstate.nn._others',
94
+ 'tanh': 'brainstate.nn._activations',
95
+ 'relu': 'brainstate.nn._activations',
96
+ 'squareplus': 'brainstate.nn._activations',
97
+ 'softplus': 'brainstate.nn._activations',
98
+ 'soft_sign': 'brainstate.nn._activations',
99
+ 'sigmoid': 'brainstate.nn._activations',
100
+ 'silu': 'brainstate.nn._activations',
101
+ 'swish': 'brainstate.nn._activations',
102
+ 'log_sigmoid': 'brainstate.nn._activations',
103
+ 'elu': 'brainstate.nn._activations',
104
+ 'leaky_relu': 'brainstate.nn._activations',
105
+ 'hard_tanh': 'brainstate.nn._activations',
106
+ 'celu': 'brainstate.nn._activations',
107
+ 'selu': 'brainstate.nn._activations',
108
+ 'gelu': 'brainstate.nn._activations',
109
+ 'glu': 'brainstate.nn._activations',
110
+ 'logsumexp': 'brainstate.nn._activations',
111
+ 'log_softmax': 'brainstate.nn._activations',
112
+ 'softmax': 'brainstate.nn._activations',
113
+ 'standardize': 'brainstate.nn._activations',
114
+ 'relu6': 'brainstate.nn._activations',
115
+ 'hard_sigmoid': 'brainstate.nn._activations',
116
+ 'sparse_plus': 'brainstate.nn._activations',
117
+ 'hard_silu': 'brainstate.nn._activations',
118
+ 'hard_swish': 'brainstate.nn._activations',
119
+ 'hard_shrink': 'brainstate.nn._activations',
120
+ 'rrelu': 'brainstate.nn._activations',
121
+ 'mish': 'brainstate.nn._activations',
122
+ 'soft_shrink': 'brainstate.nn._activations',
123
+ 'prelu': 'brainstate.nn._activations',
124
+ 'softmin': 'brainstate.nn._activations',
125
+ 'one_hot': 'brainstate.nn._activations',
126
+ 'sparse_sigmoid': 'brainstate.nn._activations',
127
+ }
128
+
129
+ functional = create_deprecated_module_proxy(
130
+ deprecated_name='brainstate.functional',
131
+ replacement_module=nn,
132
+ replacement_name='brainstate.nn',
133
+ scoped_apis=_functional_apis
134
+ )
135
+
136
+
137
+ def __getattr__(name):
138
+ if name in ['surrogate', 'init', 'optim']:
139
+ import warnings
140
+ warnings.warn(
141
+ f"brainstate.{name} module is deprecated and will be removed in a future version. "
142
+ f"Please use braintools.{name} instead.",
143
+ DeprecationWarning,
144
+ stacklevel=2
145
+ )
146
+ import braintools
147
+ return getattr(braintools, name)
148
+ raise AttributeError(
149
+ f'module {__name__!r} has no attribute {name!r}'
150
+ )
151
+
152
+
153
+ __all__ = [
154
+ 'environ',
155
+ 'graph',
156
+ 'mixin',
157
+ 'nn',
158
+ 'random',
159
+ 'transform',
160
+ 'typing',
161
+ 'util',
162
+ # Deprecated modules
163
+ 'augment',
164
+ 'compile',
165
+ 'functional',
166
+ ]
167
+ __all__ = __all__ + _state_all + _error_all
168
+ del _state_all, create_deprecated_module_proxy, _augment_apis, _compile_apis, _functional_apis
169
+ del _error_all
@@ -0,0 +1,340 @@
1
+ # Copyright 2025 BrainX Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ # -*- coding: utf-8 -*-
17
+
18
+ """
19
+ Compatibility layer for JAX version differences.
20
+
21
+ This module provides a compatibility layer to handle differences between various
22
+ versions of JAX, ensuring that BrainState works correctly across different JAX
23
+ versions. It imports the appropriate modules and functions based on the detected
24
+ JAX version and provides fallback implementations when necessary.
25
+
26
+ Key Features:
27
+ - Version-aware imports for JAX core functionality
28
+ - Compatibility wrappers for changed APIs
29
+ - Fallback implementations for deprecated functions
30
+ - Type-safe utility functions
31
+
32
+ Examples:
33
+ Basic usage:
34
+
35
+ >>> from brainstate._compatible_import import safe_map, safe_zip
36
+ >>> result = safe_map(lambda x: x * 2, [1, 2, 3])
37
+ >>> pairs = safe_zip([1, 2, 3], ['a', 'b', 'c'])
38
+
39
+ Using JAX core types:
40
+
41
+ >>> from brainstate._compatible_import import Primitive, ClosedJaxpr
42
+ >>> # These imports work across different JAX versions
43
+ """
44
+
45
+ from contextlib import contextmanager
46
+ from functools import partial
47
+ from typing import Iterable, Hashable, TypeVar, Callable
48
+
49
+ import jax
50
+ from jax.core import get_aval, Tracer
51
+ from saiunit._compatible_import import wrap_init
52
+
53
+ __all__ = [
54
+ 'ClosedJaxpr',
55
+ 'Primitive',
56
+ 'extend_axis_env_nd',
57
+ 'jaxpr_as_fun',
58
+ 'get_aval',
59
+ 'Tracer',
60
+ 'to_concrete_aval',
61
+ 'safe_map',
62
+ 'safe_zip',
63
+ 'unzip2',
64
+ 'wraps',
65
+ 'Device',
66
+ 'wrap_init',
67
+ 'Var',
68
+ 'JaxprEqn',
69
+ 'Jaxpr',
70
+ 'Literal',
71
+
72
+ 'make_iota', 'to_elt', 'BatchTracer', 'BatchTrace',
73
+ ]
74
+
75
+ T = TypeVar("T")
76
+ T1 = TypeVar("T1")
77
+ T2 = TypeVar("T2")
78
+ T3 = TypeVar("T3")
79
+
80
+ if jax.__version_info__ < (0, 5, 0):
81
+ from jax.lib.xla_client import Device
82
+ else:
83
+ from jax import Device
84
+
85
+ if jax.__version_info__ < (0, 7, 1):
86
+ from jax.interpreters.batching import make_iota, to_elt, BatchTracer, BatchTrace
87
+ else:
88
+ from jax._src.interpreters.batching import make_iota, to_elt, BatchTracer, BatchTrace
89
+
90
+ if jax.__version_info__ < (0, 4, 38):
91
+ from jax.core import ClosedJaxpr, extend_axis_env_nd, Primitive, jaxpr_as_fun
92
+ from jax.core import Primitive, Var, JaxprEqn, Jaxpr, ClosedJaxpr, Literal
93
+ else:
94
+ from jax.extend.core import ClosedJaxpr, Primitive, jaxpr_as_fun
95
+ from jax.extend.core import Primitive, Var, JaxprEqn, Jaxpr, ClosedJaxpr, Literal
96
+ from jax.core import trace_ctx
97
+
98
+
99
+ @contextmanager
100
+ def extend_axis_env_nd(name_size_pairs: Iterable[tuple[Hashable, int]]):
101
+ """
102
+ Context manager to temporarily extend the JAX axis environment.
103
+
104
+ Extends the current JAX axis environment with new named axes for
105
+ vectorized computations, then restores the previous environment.
106
+
107
+ Args:
108
+ name_size_pairs: Iterable of (name, size) tuples specifying
109
+ the named axes to add to the environment.
110
+
111
+ Yields:
112
+ None: Context with extended axis environment.
113
+
114
+ Examples:
115
+ >>> with extend_axis_env_nd([('batch', 32), ('seq', 128)]):
116
+ ... # Code using vectorized operations with named axes
117
+ ... pass
118
+ """
119
+ prev = trace_ctx.axis_env
120
+ try:
121
+ trace_ctx.set_axis_env(prev.extend_pure(name_size_pairs))
122
+ yield
123
+ finally:
124
+ trace_ctx.set_axis_env(prev)
125
+
126
+ if jax.__version_info__ < (0, 6, 0):
127
+ from jax.util import safe_map, safe_zip, unzip2, wraps
128
+
129
+ else:
130
+ def safe_map(f, *args):
131
+ """
132
+ Map a function over multiple sequences with length checking.
133
+
134
+ Applies a function to corresponding elements from multiple sequences,
135
+ ensuring all sequences have the same length.
136
+
137
+ Args:
138
+ f: Function to apply to elements from each sequence.
139
+ *args: Variable number of sequences to map over.
140
+
141
+ Returns:
142
+ list: Results of applying f to corresponding elements.
143
+
144
+ Raises:
145
+ AssertionError: If input sequences have different lengths.
146
+
147
+ Examples:
148
+ >>> safe_map(lambda x, y: x + y, [1, 2, 3], [4, 5, 6])
149
+ [5, 7, 9]
150
+
151
+ >>> safe_map(str.upper, ['a', 'b', 'c'])
152
+ ['A', 'B', 'C']
153
+ """
154
+ args = list(map(list, args))
155
+ n = len(args[0])
156
+ for arg in args[1:]:
157
+ assert len(arg) == n, f'length mismatch: {list(map(len, args))}'
158
+ return list(map(f, *args))
159
+
160
+
161
+ def safe_zip(*args):
162
+ """
163
+ Zip multiple sequences with length checking.
164
+
165
+ Combines corresponding elements from multiple sequences into tuples,
166
+ ensuring all sequences have the same length.
167
+
168
+ Args:
169
+ *args: Variable number of sequences to zip together.
170
+
171
+ Returns:
172
+ list: List of tuples containing corresponding elements.
173
+
174
+ Raises:
175
+ AssertionError: If input sequences have different lengths.
176
+
177
+ Examples:
178
+ >>> safe_zip([1, 2, 3], ['a', 'b', 'c'])
179
+ [(1, 'a'), (2, 'b'), (3, 'c')]
180
+
181
+ >>> safe_zip([1, 2], [3, 4], [5, 6])
182
+ [(1, 3, 5), (2, 4, 6)]
183
+ """
184
+ args = list(map(list, args))
185
+ n = len(args[0])
186
+ for arg in args[1:]:
187
+ assert len(arg) == n, f'length mismatch: {list(map(len, args))}'
188
+ return list(zip(*args))
189
+
190
+
191
+ def unzip2(xys: Iterable[tuple[T1, T2]]) -> tuple[tuple[T1, ...], tuple[T2, ...]]:
192
+ """
193
+ Unzip sequence of length-2 tuples into two tuples.
194
+
195
+ Takes an iterable of 2-tuples and separates them into two tuples
196
+ containing the first and second elements respectively.
197
+
198
+ Args:
199
+ xys: Iterable of 2-tuples to unzip.
200
+
201
+ Returns:
202
+ tuple: A 2-tuple containing:
203
+ - Tuple of all first elements
204
+ - Tuple of all second elements
205
+
206
+ Examples:
207
+ >>> pairs = [(1, 'a'), (2, 'b'), (3, 'c')]
208
+ >>> nums, letters = unzip2(pairs)
209
+ >>> nums
210
+ (1, 2, 3)
211
+ >>> letters
212
+ ('a', 'b', 'c')
213
+
214
+ Notes:
215
+ We deliberately don't use zip(*xys) because it is lazily evaluated,
216
+ is too permissive about inputs, and does not guarantee a length-2 output.
217
+ """
218
+ # Note: we deliberately don't use zip(*xys) because it is lazily evaluated,
219
+ # is too permissive about inputs, and does not guarantee a length-2 output.
220
+ xs: list[T1] = []
221
+ ys: list[T2] = []
222
+ for x, y in xys:
223
+ xs.append(x)
224
+ ys.append(y)
225
+ return tuple(xs), tuple(ys)
226
+
227
+
228
+ def fun_name(fun: Callable):
229
+ """
230
+ Extract the name of a function, handling special cases.
231
+
232
+ Attempts to get the name of a function, with special handling for
233
+ partial functions and fallback for unnamed functions.
234
+
235
+ Args:
236
+ fun: The function to get the name from.
237
+
238
+ Returns:
239
+ str: The function name, or "<unnamed function>" if no name available.
240
+
241
+ Examples:
242
+ >>> def my_function():
243
+ ... pass
244
+ >>> fun_name(my_function)
245
+ 'my_function'
246
+
247
+ >>> from functools import partial
248
+ >>> add = lambda x, y: x + y
249
+ >>> add_one = partial(add, 1)
250
+ >>> fun_name(add_one)
251
+ '<lambda>'
252
+ """
253
+ name = getattr(fun, "__name__", None)
254
+ if name is not None:
255
+ return name
256
+ if isinstance(fun, partial):
257
+ return fun_name(fun.func)
258
+ else:
259
+ return "<unnamed function>"
260
+
261
+
262
+ def wraps(
263
+ wrapped: Callable,
264
+ namestr: str | None = None,
265
+ docstr: str | None = None,
266
+ **kwargs,
267
+ ) -> Callable[[T], T]:
268
+ """
269
+ Enhanced function wrapper with fine-grained control.
270
+
271
+ Like functools.wraps, but provides more control over the name and docstring
272
+ of the resulting function. Useful for creating custom decorators.
273
+
274
+ Args:
275
+ wrapped: The function being wrapped.
276
+ namestr: Optional format string for the wrapper function name.
277
+ Can use {fun} placeholder for the original function name.
278
+ docstr: Optional format string for the wrapper function docstring.
279
+ Can use {fun}, {doc}, and other kwargs as placeholders.
280
+ **kwargs: Additional keyword arguments for format string substitution.
281
+
282
+ Returns:
283
+ Callable: A decorator function that applies the wrapping.
284
+
285
+ Examples:
286
+ >>> def my_decorator(func):
287
+ ... @wraps(func, namestr="decorated_{fun}")
288
+ ... def wrapper(*args, **kwargs):
289
+ ... return func(*args, **kwargs)
290
+ ... return wrapper
291
+
292
+ >>> @my_decorator
293
+ ... def example():
294
+ ... pass
295
+ >>> example.__name__
296
+ 'decorated_example'
297
+ """
298
+
299
+ def wrapper(fun: T) -> T:
300
+ try:
301
+ name = fun_name(wrapped)
302
+ doc = getattr(wrapped, "__doc__", "") or ""
303
+ fun.__dict__.update(getattr(wrapped, "__dict__", {}))
304
+ fun.__annotations__ = getattr(wrapped, "__annotations__", {})
305
+ fun.__name__ = name if namestr is None else namestr.format(fun=name)
306
+ fun.__module__ = getattr(wrapped, "__module__", "<unknown module>")
307
+ fun.__doc__ = (doc if docstr is None
308
+ else docstr.format(fun=name, doc=doc, **kwargs))
309
+ fun.__qualname__ = getattr(wrapped, "__qualname__", fun.__name__)
310
+ fun.__wrapped__ = wrapped
311
+ except Exception:
312
+ pass
313
+ return fun
314
+
315
+ return wrapper
316
+
317
+
318
+ def to_concrete_aval(aval):
319
+ """
320
+ Convert an abstract value to its concrete representation.
321
+
322
+ Takes an abstract value and attempts to convert it to a concrete value,
323
+ handling JAX Tracer objects appropriately.
324
+
325
+ Args:
326
+ aval: The abstract value to convert.
327
+
328
+ Returns:
329
+ The concrete value representation, or the original aval if already concrete.
330
+
331
+ Examples:
332
+ >>> import jax.numpy as jnp
333
+ >>> arr = jnp.array([1, 2, 3])
334
+ >>> concrete = to_concrete_aval(arr)
335
+ # Returns the concrete array value
336
+ """
337
+ aval = get_aval(aval)
338
+ if isinstance(aval, Tracer):
339
+ return aval.to_concrete_value()
340
+ return aval