brainstate 0.1.10__tar.gz → 0.2.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {brainstate-0.1.10 → brainstate-0.2.0}/PKG-INFO +34 -17
- {brainstate-0.1.10 → brainstate-0.2.0}/README.md +3 -3
- brainstate-0.2.0/brainstate/__init__.py +169 -0
- brainstate-0.2.0/brainstate/_compatible_import.py +340 -0
- brainstate-0.2.0/brainstate/_compatible_import_test.py +681 -0
- brainstate-0.2.0/brainstate/_deprecation.py +210 -0
- brainstate-0.2.0/brainstate/_deprecation_test.py +2319 -0
- brainstate-0.1.10/brainstate/util/error.py → brainstate-0.2.0/brainstate/_error.py +10 -20
- {brainstate-0.1.10 → brainstate-0.2.0}/brainstate/_state.py +94 -47
- {brainstate-0.1.10 → brainstate-0.2.0}/brainstate/_state_test.py +1 -1
- {brainstate-0.1.10 → brainstate-0.2.0}/brainstate/_utils.py +1 -1
- brainstate-0.2.0/brainstate/environ.py +1495 -0
- brainstate-0.2.0/brainstate/environ_test.py +1223 -0
- brainstate-0.1.10/brainstate/transform.py → brainstate-0.2.0/brainstate/graph/__init__.py +6 -7
- brainstate-0.2.0/brainstate/graph/_node.py +240 -0
- brainstate-0.2.0/brainstate/graph/_node_test.py +589 -0
- brainstate-0.1.10/brainstate/graph/_graph_operation.py → brainstate-0.2.0/brainstate/graph/_operation.py +632 -746
- brainstate-0.2.0/brainstate/graph/_operation_test.py +1147 -0
- brainstate-0.2.0/brainstate/mixin.py +1433 -0
- brainstate-0.2.0/brainstate/mixin_test.py +1017 -0
- brainstate-0.2.0/brainstate/nn/__init__.py +137 -0
- brainstate-0.2.0/brainstate/nn/_activations.py +1100 -0
- {brainstate-0.1.10 → brainstate-0.2.0}/brainstate/nn/_activations_test.py +109 -86
- brainstate-0.2.0/brainstate/nn/_collective_ops.py +633 -0
- brainstate-0.2.0/brainstate/nn/_collective_ops_test.py +774 -0
- brainstate-0.2.0/brainstate/nn/_common.py +226 -0
- brainstate-0.2.0/brainstate/nn/_common_test.py +154 -0
- brainstate-0.2.0/brainstate/nn/_conv.py +2010 -0
- brainstate-0.2.0/brainstate/nn/_conv_test.py +849 -0
- {brainstate-0.1.10 → brainstate-0.2.0}/brainstate/nn/_delay.py +15 -28
- {brainstate-0.1.10 → brainstate-0.2.0}/brainstate/nn/_delay_test.py +25 -20
- brainstate-0.2.0/brainstate/nn/_dropout.py +618 -0
- brainstate-0.2.0/brainstate/nn/_dropout_test.py +477 -0
- {brainstate-0.1.10 → brainstate-0.2.0}/brainstate/nn/_dynamics.py +14 -90
- {brainstate-0.1.10 → brainstate-0.2.0}/brainstate/nn/_dynamics_test.py +1 -12
- {brainstate-0.1.10 → brainstate-0.2.0}/brainstate/nn/_elementwise.py +492 -313
- brainstate-0.2.0/brainstate/nn/_elementwise_test.py +830 -0
- brainstate-0.2.0/brainstate/nn/_embedding.py +408 -0
- brainstate-0.2.0/brainstate/nn/_embedding_test.py +156 -0
- brainstate-0.1.10/brainstate/nn/_fixedprob.py → brainstate-0.2.0/brainstate/nn/_event_fixedprob.py +10 -16
- brainstate-0.1.10/brainstate/nn/_fixedprob_test.py → brainstate-0.2.0/brainstate/nn/_event_fixedprob_test.py +6 -5
- brainstate-0.1.10/brainstate/nn/_linear_mv.py → brainstate-0.2.0/brainstate/nn/_event_linear.py +2 -2
- brainstate-0.1.10/brainstate/nn/_linear_mv_test.py → brainstate-0.2.0/brainstate/nn/_event_linear_test.py +6 -5
- brainstate-0.2.0/brainstate/nn/_exp_euler.py +254 -0
- brainstate-0.2.0/brainstate/nn/_exp_euler_test.py +377 -0
- brainstate-0.2.0/brainstate/nn/_linear.py +744 -0
- brainstate-0.2.0/brainstate/nn/_linear_test.py +475 -0
- brainstate-0.2.0/brainstate/nn/_metrics.py +1070 -0
- brainstate-0.2.0/brainstate/nn/_metrics_test.py +611 -0
- {brainstate-0.1.10 → brainstate-0.2.0}/brainstate/nn/_module.py +10 -3
- {brainstate-0.1.10 → brainstate-0.2.0}/brainstate/nn/_module_test.py +1 -1
- brainstate-0.2.0/brainstate/nn/_normalizations.py +1334 -0
- brainstate-0.2.0/brainstate/nn/_normalizations_test.py +699 -0
- brainstate-0.2.0/brainstate/nn/_paddings.py +1020 -0
- brainstate-0.2.0/brainstate/nn/_paddings_test.py +723 -0
- brainstate-0.2.0/brainstate/nn/_poolings.py +2239 -0
- brainstate-0.2.0/brainstate/nn/_poolings_test.py +953 -0
- brainstate-0.1.10/brainstate/nn/_rate_rnns.py → brainstate-0.2.0/brainstate/nn/_rnns.py +446 -54
- brainstate-0.2.0/brainstate/nn/_rnns_test.py +593 -0
- brainstate-0.2.0/brainstate/nn/_utils.py +216 -0
- brainstate-0.2.0/brainstate/nn/_utils_test.py +402 -0
- brainstate-0.1.10/brainstate/init/_random_inits.py → brainstate-0.2.0/brainstate/nn/init.py +301 -45
- brainstate-0.1.10/brainstate/init/_random_inits_test.py → brainstate-0.2.0/brainstate/nn/init_test.py +51 -20
- brainstate-0.2.0/brainstate/random/__init__.py +270 -0
- {brainstate-0.1.10 → brainstate-0.2.0}/brainstate/random/_rand_funs.py +668 -346
- {brainstate-0.1.10 → brainstate-0.2.0}/brainstate/random/_rand_funs_test.py +74 -1
- brainstate-0.2.0/brainstate/random/_rand_seed.py +675 -0
- {brainstate-0.1.10 → brainstate-0.2.0}/brainstate/random/_rand_seed_test.py +1 -1
- {brainstate-0.1.10 → brainstate-0.2.0}/brainstate/random/_rand_state.py +601 -393
- brainstate-0.2.0/brainstate/random/_rand_state_test.py +551 -0
- brainstate-0.2.0/brainstate/transform/__init__.py +59 -0
- brainstate-0.2.0/brainstate/transform/_ad_checkpoint.py +176 -0
- {brainstate-0.1.10/brainstate/compile → brainstate-0.2.0/brainstate/transform}/_ad_checkpoint_test.py +1 -1
- {brainstate-0.1.10/brainstate/augment → brainstate-0.2.0/brainstate/transform}/_autograd.py +360 -113
- {brainstate-0.1.10/brainstate/augment → brainstate-0.2.0/brainstate/transform}/_autograd_test.py +2 -2
- brainstate-0.2.0/brainstate/transform/_conditions.py +316 -0
- {brainstate-0.1.10/brainstate/compile → brainstate-0.2.0/brainstate/transform}/_conditions_test.py +11 -11
- {brainstate-0.1.10/brainstate/compile → brainstate-0.2.0/brainstate/transform}/_error_if.py +22 -20
- {brainstate-0.1.10/brainstate/compile → brainstate-0.2.0/brainstate/transform}/_error_if_test.py +1 -1
- brainstate-0.2.0/brainstate/transform/_eval_shape.py +145 -0
- {brainstate-0.1.10/brainstate/augment → brainstate-0.2.0/brainstate/transform}/_eval_shape_test.py +1 -1
- {brainstate-0.1.10/brainstate/compile → brainstate-0.2.0/brainstate/transform}/_jit.py +99 -46
- {brainstate-0.1.10/brainstate/compile → brainstate-0.2.0/brainstate/transform}/_jit_test.py +3 -3
- {brainstate-0.1.10/brainstate/compile → brainstate-0.2.0/brainstate/transform}/_loop_collect_return.py +219 -80
- {brainstate-0.1.10/brainstate/compile → brainstate-0.2.0/brainstate/transform}/_loop_collect_return_test.py +1 -1
- {brainstate-0.1.10/brainstate/compile → brainstate-0.2.0/brainstate/transform}/_loop_no_collection.py +133 -34
- {brainstate-0.1.10/brainstate/compile → brainstate-0.2.0/brainstate/transform}/_loop_no_collection_test.py +2 -2
- brainstate-0.2.0/brainstate/transform/_make_jaxpr.py +2016 -0
- brainstate-0.2.0/brainstate/transform/_make_jaxpr_test.py +1510 -0
- brainstate-0.2.0/brainstate/transform/_mapping.py +529 -0
- brainstate-0.2.0/brainstate/transform/_mapping_test.py +194 -0
- {brainstate-0.1.10/brainstate/compile → brainstate-0.2.0/brainstate/transform}/_progress_bar.py +78 -25
- {brainstate-0.1.10/brainstate/augment → brainstate-0.2.0/brainstate/transform}/_random.py +65 -45
- {brainstate-0.1.10/brainstate/compile → brainstate-0.2.0/brainstate/transform}/_unvmap.py +102 -5
- brainstate-0.2.0/brainstate/transform/_util.py +286 -0
- brainstate-0.2.0/brainstate/typing.py +837 -0
- brainstate-0.2.0/brainstate/typing_test.py +780 -0
- {brainstate-0.1.10/brainstate/random → brainstate-0.2.0/brainstate/util}/__init__.py +12 -9
- brainstate-0.2.0/brainstate/util/_others.py +1025 -0
- brainstate-0.2.0/brainstate/util/_others_test.py +962 -0
- brainstate-0.2.0/brainstate/util/_pretty_pytree.py +1301 -0
- brainstate-0.2.0/brainstate/util/_pretty_pytree_test.py +675 -0
- brainstate-0.1.10/brainstate/util/pretty_repr.py → brainstate-0.2.0/brainstate/util/_pretty_repr.py +161 -27
- brainstate-0.2.0/brainstate/util/_pretty_repr_test.py +696 -0
- brainstate-0.2.0/brainstate/util/filter.py +945 -0
- brainstate-0.2.0/brainstate/util/filter_test.py +912 -0
- brainstate-0.2.0/brainstate/util/struct.py +910 -0
- brainstate-0.2.0/brainstate/util/struct_test.py +602 -0
- {brainstate-0.1.10 → brainstate-0.2.0}/brainstate.egg-info/PKG-INFO +34 -17
- brainstate-0.2.0/brainstate.egg-info/SOURCES.txt +114 -0
- brainstate-0.2.0/brainstate.egg-info/requires.txt +31 -0
- {brainstate-0.1.10 → brainstate-0.2.0}/pyproject.toml +50 -14
- brainstate-0.1.10/brainstate/__init__.py +0 -58
- brainstate-0.1.10/brainstate/_compatible_import.py +0 -148
- brainstate-0.1.10/brainstate/augment/__init__.py +0 -30
- brainstate-0.1.10/brainstate/augment/_eval_shape.py +0 -99
- brainstate-0.1.10/brainstate/augment/_mapping.py +0 -1060
- brainstate-0.1.10/brainstate/augment/_mapping_test.py +0 -597
- brainstate-0.1.10/brainstate/compile/__init__.py +0 -38
- brainstate-0.1.10/brainstate/compile/_ad_checkpoint.py +0 -204
- brainstate-0.1.10/brainstate/compile/_conditions.py +0 -256
- brainstate-0.1.10/brainstate/compile/_make_jaxpr.py +0 -888
- brainstate-0.1.10/brainstate/compile/_make_jaxpr_test.py +0 -156
- brainstate-0.1.10/brainstate/compile/_util.py +0 -147
- brainstate-0.1.10/brainstate/environ.py +0 -563
- brainstate-0.1.10/brainstate/environ_test.py +0 -62
- brainstate-0.1.10/brainstate/functional/__init__.py +0 -27
- brainstate-0.1.10/brainstate/graph/__init__.py +0 -29
- brainstate-0.1.10/brainstate/graph/_graph_node.py +0 -244
- brainstate-0.1.10/brainstate/graph/_graph_node_test.py +0 -73
- brainstate-0.1.10/brainstate/graph/_graph_operation_test.py +0 -563
- brainstate-0.1.10/brainstate/init/__init__.py +0 -26
- brainstate-0.1.10/brainstate/init/_base.py +0 -52
- brainstate-0.1.10/brainstate/init/_generic.py +0 -244
- brainstate-0.1.10/brainstate/init/_regular_inits.py +0 -105
- brainstate-0.1.10/brainstate/init/_regular_inits_test.py +0 -50
- brainstate-0.1.10/brainstate/mixin.py +0 -365
- brainstate-0.1.10/brainstate/mixin_test.py +0 -77
- brainstate-0.1.10/brainstate/nn/__init__.py +0 -135
- brainstate-0.1.10/brainstate/nn/_activations.py +0 -808
- brainstate-0.1.10/brainstate/nn/_collective_ops.py +0 -514
- brainstate-0.1.10/brainstate/nn/_collective_ops_test.py +0 -43
- brainstate-0.1.10/brainstate/nn/_common.py +0 -178
- brainstate-0.1.10/brainstate/nn/_conv.py +0 -501
- brainstate-0.1.10/brainstate/nn/_conv_test.py +0 -238
- brainstate-0.1.10/brainstate/nn/_dropout.py +0 -426
- brainstate-0.1.10/brainstate/nn/_dropout_test.py +0 -100
- brainstate-0.1.10/brainstate/nn/_elementwise_test.py +0 -169
- brainstate-0.1.10/brainstate/nn/_embedding.py +0 -58
- brainstate-0.1.10/brainstate/nn/_exp_euler.py +0 -92
- brainstate-0.1.10/brainstate/nn/_exp_euler_test.py +0 -35
- brainstate-0.1.10/brainstate/nn/_inputs.py +0 -608
- brainstate-0.1.10/brainstate/nn/_linear.py +0 -424
- brainstate-0.1.10/brainstate/nn/_linear_test.py +0 -107
- brainstate-0.1.10/brainstate/nn/_ltp.py +0 -28
- brainstate-0.1.10/brainstate/nn/_neuron.py +0 -705
- brainstate-0.1.10/brainstate/nn/_neuron_test.py +0 -161
- brainstate-0.1.10/brainstate/nn/_normalizations.py +0 -975
- brainstate-0.1.10/brainstate/nn/_normalizations_test.py +0 -73
- brainstate-0.1.10/brainstate/nn/_others.py +0 -46
- brainstate-0.1.10/brainstate/nn/_poolings.py +0 -1177
- brainstate-0.1.10/brainstate/nn/_poolings_test.py +0 -217
- brainstate-0.1.10/brainstate/nn/_projection.py +0 -486
- brainstate-0.1.10/brainstate/nn/_rate_rnns_test.py +0 -63
- brainstate-0.1.10/brainstate/nn/_readout.py +0 -209
- brainstate-0.1.10/brainstate/nn/_readout_test.py +0 -53
- brainstate-0.1.10/brainstate/nn/_stp.py +0 -236
- brainstate-0.1.10/brainstate/nn/_synapse.py +0 -505
- brainstate-0.1.10/brainstate/nn/_synapse_test.py +0 -131
- brainstate-0.1.10/brainstate/nn/_synaptic_projection.py +0 -423
- brainstate-0.1.10/brainstate/nn/_synouts.py +0 -162
- brainstate-0.1.10/brainstate/nn/_synouts_test.py +0 -57
- brainstate-0.1.10/brainstate/nn/_utils.py +0 -89
- brainstate-0.1.10/brainstate/nn/metrics.py +0 -388
- brainstate-0.1.10/brainstate/optim/__init__.py +0 -38
- brainstate-0.1.10/brainstate/optim/_base.py +0 -64
- brainstate-0.1.10/brainstate/optim/_lr_scheduler.py +0 -448
- brainstate-0.1.10/brainstate/optim/_lr_scheduler_test.py +0 -50
- brainstate-0.1.10/brainstate/optim/_optax_optimizer.py +0 -152
- brainstate-0.1.10/brainstate/optim/_optax_optimizer_test.py +0 -53
- brainstate-0.1.10/brainstate/optim/_sgd_optimizer.py +0 -1104
- brainstate-0.1.10/brainstate/random/_rand_seed.py +0 -210
- brainstate-0.1.10/brainstate/random/_random_for_unit.py +0 -52
- brainstate-0.1.10/brainstate/surrogate.py +0 -1957
- brainstate-0.1.10/brainstate/typing.py +0 -304
- brainstate-0.1.10/brainstate/util/__init__.py +0 -50
- brainstate-0.1.10/brainstate/util/caller.py +0 -98
- brainstate-0.1.10/brainstate/util/filter.py +0 -469
- brainstate-0.1.10/brainstate/util/others.py +0 -540
- brainstate-0.1.10/brainstate/util/pretty_pytree.py +0 -945
- brainstate-0.1.10/brainstate/util/pretty_pytree_test.py +0 -159
- brainstate-0.1.10/brainstate/util/pretty_table.py +0 -2954
- brainstate-0.1.10/brainstate/util/scaling.py +0 -258
- brainstate-0.1.10/brainstate/util/struct.py +0 -523
- brainstate-0.1.10/brainstate.egg-info/SOURCES.txt +0 -134
- brainstate-0.1.10/brainstate.egg-info/requires.txt +0 -8
- brainstate-0.1.10/setup.py +0 -96
- {brainstate-0.1.10 → brainstate-0.2.0}/LICENSE +0 -0
- {brainstate-0.1.10 → brainstate-0.2.0}/brainstate.egg-info/dependency_links.txt +0 -0
- {brainstate-0.1.10 → brainstate-0.2.0}/brainstate.egg-info/top_level.txt +0 -0
- {brainstate-0.1.10 → brainstate-0.2.0}/setup.cfg +0 -0
@@ -1,14 +1,15 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: brainstate
|
3
|
-
Version: 0.1.10
|
4
|
-
Summary: A
|
5
|
-
Home-page: https://github.com/chaobrain/brainstate
|
6
|
-
Author: BrainState Developers
|
3
|
+
Version: 0.2.0
|
4
|
+
Summary: A State-based Transformation System for Brain Modeling.
|
7
5
|
Author-email: BrainState Developers <chao.brain@qq.com>
|
8
6
|
License: Apache-2.0 license
|
9
|
-
Project-URL: homepage,
|
10
|
-
Project-URL: repository,
|
11
|
-
|
7
|
+
Project-URL: homepage, https://github.com/chaobrain/brainstate
|
8
|
+
Project-URL: repository, https://github.com/chaobrain/brainstate
|
9
|
+
Project-URL: Documentation, https://brainstate.readthedocs.io/
|
10
|
+
Project-URL: Source Code, https://github.com/chaobrain/brainstate
|
11
|
+
Project-URL: Bug Tracker, https://github.com/chaobrain/brainstate/issues
|
12
|
+
Keywords: computational neuroscience,brain-inspired computing,brain simulation,brain modeling,spiking neural networks
|
12
13
|
Classifier: Natural Language :: English
|
13
14
|
Classifier: Operating System :: OS Independent
|
14
15
|
Classifier: Development Status :: 4 - Beta
|
@@ -28,20 +29,36 @@ Classifier: Topic :: Software Development :: Libraries
|
|
28
29
|
Requires-Python: >=3.10
|
29
30
|
Description-Content-Type: text/markdown
|
30
31
|
License-File: LICENSE
|
31
|
-
Requires-Dist:
|
32
|
-
Requires-Dist:
|
33
|
-
Requires-Dist:
|
34
|
-
Requires-Dist: brainunit>=0.1.0
|
32
|
+
Requires-Dist: numpy>=1.15
|
33
|
+
Requires-Dist: tqdm
|
34
|
+
Requires-Dist: brainunit
|
35
35
|
Requires-Dist: brainevent
|
36
|
+
Provides-Extra: cpu
|
37
|
+
Requires-Dist: jax[cpu]; extra == "cpu"
|
38
|
+
Requires-Dist: brainunit; extra == "cpu"
|
39
|
+
Requires-Dist: brainevent; extra == "cpu"
|
40
|
+
Provides-Extra: cuda12
|
41
|
+
Requires-Dist: jax[cuda12]; extra == "cuda12"
|
42
|
+
Requires-Dist: brainunit; extra == "cuda12"
|
43
|
+
Requires-Dist: brainevent; extra == "cuda12"
|
44
|
+
Provides-Extra: cuda13
|
45
|
+
Requires-Dist: jax[cuda13]; extra == "cuda13"
|
46
|
+
Requires-Dist: brainunit; extra == "cuda13"
|
47
|
+
Requires-Dist: brainevent; extra == "cuda13"
|
48
|
+
Provides-Extra: tpu
|
49
|
+
Requires-Dist: jax[tpu]; extra == "tpu"
|
50
|
+
Requires-Dist: brainunit; extra == "tpu"
|
51
|
+
Requires-Dist: brainevent; extra == "tpu"
|
36
52
|
Provides-Extra: testing
|
53
|
+
Requires-Dist: absl-py; extra == "testing"
|
37
54
|
Requires-Dist: pytest; extra == "testing"
|
38
|
-
|
39
|
-
|
55
|
+
Requires-Dist: jax; extra == "testing"
|
56
|
+
Requires-Dist: brainunit; extra == "testing"
|
57
|
+
Requires-Dist: brainevent; extra == "testing"
|
40
58
|
Dynamic: license-file
|
41
|
-
Dynamic: requires-python
|
42
59
|
|
43
60
|
|
44
|
-
# A ``State``-based Transformation System for
|
61
|
+
# A ``State``-based Transformation System for Brain Modeling
|
45
62
|
|
46
63
|
|
47
64
|
|
@@ -84,8 +101,8 @@ The official documentation is hosted on Read the Docs: [https://brainstate.readt
|
|
84
101
|
|
85
102
|
|
86
103
|
|
87
|
-
## See also the
|
104
|
+
## See also the ecosystem
|
88
105
|
|
89
|
-
|
106
|
+
``brainstate`` is one part of our brain simulation ecosystem: https://brainmodeling.readthedocs.io/
|
90
107
|
|
91
108
|
|
@@ -1,5 +1,5 @@
|
|
1
1
|
|
2
|
-
# A ``State``-based Transformation System for
|
2
|
+
# A ``State``-based Transformation System for Brain Modeling
|
3
3
|
|
4
4
|
|
5
5
|
|
@@ -42,8 +42,8 @@ The official documentation is hosted on Read the Docs: [https://brainstate.readt
|
|
42
42
|
|
43
43
|
|
44
44
|
|
45
|
-
## See also the
|
45
|
+
## See also the ecosystem
|
46
46
|
|
47
|
-
|
47
|
+
``brainstate`` is one part of our brain simulation ecosystem: https://brainmodeling.readthedocs.io/
|
48
48
|
|
49
49
|
|
@@ -0,0 +1,169 @@
|
|
1
|
+
# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
"""
|
17
|
+
A ``State``-based Transformation System for Program Compilation and Augmentation
|
18
|
+
"""
|
19
|
+
|
20
|
+
__version__ = "0.2.0"
# Version as a tuple, under the conventional ``__version_info__`` spelling
# (the same attribute name this package reads from ``jax``).
__version_info__ = (0, 2, 0)
# Backward-compatible alias: the tuple was originally published under this
# misspelled name, so it is kept for any code that already references it.
__versio_info__ = __version_info__
|
22
|
+
|
23
|
+
from . import environ
|
24
|
+
from . import graph
|
25
|
+
from . import mixin
|
26
|
+
from . import nn
|
27
|
+
from . import random
|
28
|
+
from . import transform
|
29
|
+
from . import typing
|
30
|
+
from . import util
|
31
|
+
from ._error import *
|
32
|
+
from ._error import __all__ as _error_all
|
33
|
+
from ._state import *
|
34
|
+
from ._state import __all__ as _state_all
|
35
|
+
|
36
|
+
# Create deprecated module proxies with scoped APIs
|
37
|
+
from ._deprecation import create_deprecated_module_proxy
|
38
|
+
|
39
|
+
# Augment module scope
# Names that used to be reachable as ``brainstate.augment.<name>``, grouped by
# the ``brainstate.transform`` submodule path recorded for them. The groups are
# flattened into a ``{api_name: module_path}`` dict, keeping the listed order.
_augment_apis = {
    api_name: module_path
    for module_path, api_names in (
        (
            'brainstate.transform._autograd',
            (
                'GradientTransform',
                'grad',
                'vector_grad',
                'hessian',
                'jacobian',
                'jacrev',
                'jacfwd',
            ),
        ),
        ('brainstate.transform._eval_shape', ('abstract_init',)),
        ('brainstate.transform._mapping', ('vmap', 'pmap', 'map', 'vmap_new_states')),
        ('brainstate.transform._random', ('restore_rngs',)),
    )
    for api_name in api_names
}

# Deprecated ``brainstate.augment`` proxy backed by the ``transform`` module.
# NOTE(review): ``scoped_apis`` presumably limits the proxy to the names above;
# confirm against ``_deprecation.create_deprecated_module_proxy``.
augment = create_deprecated_module_proxy(
    scoped_apis=_augment_apis,
    replacement_name='brainstate.transform',
    replacement_module=transform,
    deprecated_name='brainstate.augment',
)
|
62
|
+
|
63
|
+
# Compile module scope
# Names that used to be reachable as ``brainstate.compile.<name>``, grouped by
# the ``brainstate.transform`` submodule path recorded for them. The groups are
# flattened into a ``{api_name: module_path}`` dict, keeping the listed order.
_compile_apis = {
    api_name: module_path
    for module_path, api_names in (
        ('brainstate.transform._ad_checkpoint', ('checkpoint', 'remat')),
        ('brainstate.transform._conditions', ('cond', 'switch', 'ifelse')),
        ('brainstate.transform._error_if', ('jit_error_if',)),
        ('brainstate.transform._jit', ('jit',)),
        (
            'brainstate.transform._loop_collect_return',
            ('scan', 'checkpointed_scan', 'for_loop', 'checkpointed_for_loop'),
        ),
        ('brainstate.transform._loop_no_collection', ('while_loop', 'bounded_while_loop')),
        ('brainstate.transform._make_jaxpr', ('StatefulFunction', 'make_jaxpr')),
        ('brainstate.transform._progress_bar', ('ProgressBar',)),
    )
    for api_name in api_names
}

# Deprecated ``brainstate.compile`` proxy backed by the ``transform`` module.
# NOTE(review): ``scoped_apis`` presumably limits the proxy to the names above;
# confirm against ``_deprecation.create_deprecated_module_proxy``.
compile = create_deprecated_module_proxy(
    scoped_apis=_compile_apis,
    replacement_name='brainstate.transform',
    replacement_module=transform,
    deprecated_name='brainstate.compile',
)
|
89
|
+
|
90
|
+
# Functional module scope - use direct attribute access from nn module
# Names that used to be reachable as ``brainstate.functional.<name>``, grouped
# by the ``brainstate.nn`` submodule path recorded for them. The groups are
# flattened into a ``{api_name: module_path}`` dict, keeping the listed order.
_functional_apis = {
    api_name: module_path
    for module_path, api_names in (
        ('brainstate.nn._normalizations', ('weight_standardization',)),
        # NOTE(review): no ``brainstate/nn/_others.py`` appears to ship in 0.2.0 -
        # confirm this path still resolves (or that the proxy never imports it).
        ('brainstate.nn._others', ('clip_grad_norm',)),
        (
            'brainstate.nn._activations',
            (
                'tanh',
                'relu',
                'squareplus',
                'softplus',
                'soft_sign',
                'sigmoid',
                'silu',
                'swish',
                'log_sigmoid',
                'elu',
                'leaky_relu',
                'hard_tanh',
                'celu',
                'selu',
                'gelu',
                'glu',
                'logsumexp',
                'log_softmax',
                'softmax',
                'standardize',
                'relu6',
                'hard_sigmoid',
                'sparse_plus',
                'hard_silu',
                'hard_swish',
                'hard_shrink',
                'rrelu',
                'mish',
                'soft_shrink',
                'prelu',
                'softmin',
                'one_hot',
                'sparse_sigmoid',
            ),
        ),
    )
    for api_name in api_names
}

# Deprecated ``brainstate.functional`` proxy backed by the ``nn`` module.
functional = create_deprecated_module_proxy(
    scoped_apis=_functional_apis,
    replacement_name='brainstate.nn',
    replacement_module=nn,
    deprecated_name='brainstate.functional',
)
|
135
|
+
|
136
|
+
|
137
|
+
def __getattr__(name):
    """
    Module-level attribute hook for names that now live in ``braintools``.

    ``surrogate``, ``init`` and ``optim`` are forwarded to the attribute of the
    same name on the ``braintools`` package, after a ``DeprecationWarning`` has
    been issued. Any other name raises ``AttributeError``.

    Args:
        name: Name of the attribute being looked up on this module.

    Returns:
        The attribute called ``name`` taken from ``braintools``.

    Raises:
        AttributeError: If ``name`` is not one of the forwarded names.
    """
    # Guard clause: everything except the forwarded names is a plain miss.
    if name not in ['surrogate', 'init', 'optim']:
        raise AttributeError(
            f'module {__name__!r} has no attribute {name!r}'
        )

    import warnings

    message = (
        f"brainstate.{name} module is deprecated and will be removed in a future version. "
        f"Please use braintools.{name} instead."
    )
    # stacklevel=2 attributes the warning to the code that accessed the name.
    warnings.warn(message, DeprecationWarning, stacklevel=2)

    # Imported only after the warning, so the notice is issued even when the
    # import itself fails.
    # NOTE(review): confirm ``braintools`` is a declared dependency.
    import braintools
    return getattr(braintools, name)
|
151
|
+
|
152
|
+
|
153
|
+
# Public names: the submodules, the deprecated proxy modules, and the names
# exported by ``_state`` and ``_error`` (their ``__all__`` lists).
__all__ = [
    'environ',
    'graph',
    'mixin',
    'nn',
    'random',
    'transform',
    'typing',
    'util',
    # Deprecated modules
    'augment',
    'compile',
    'functional',
] + _state_all + _error_all

# Remove construction-time helpers so they do not remain in the package namespace.
del (
    _state_all,
    _error_all,
    create_deprecated_module_proxy,
    _augment_apis,
    _compile_apis,
    _functional_apis,
)
|
@@ -0,0 +1,340 @@
|
|
1
|
+
# Copyright 2025 BrainX Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
# -*- coding: utf-8 -*-
|
17
|
+
|
18
|
+
"""
|
19
|
+
Compatibility layer for JAX version differences.
|
20
|
+
|
21
|
+
This module provides a compatibility layer to handle differences between various
|
22
|
+
versions of JAX, ensuring that BrainState works correctly across different JAX
|
23
|
+
versions. It imports the appropriate modules and functions based on the detected
|
24
|
+
JAX version and provides fallback implementations when necessary.
|
25
|
+
|
26
|
+
Key Features:
|
27
|
+
- Version-aware imports for JAX core functionality
|
28
|
+
- Compatibility wrappers for changed APIs
|
29
|
+
- Fallback implementations for deprecated functions
|
30
|
+
- Type-safe utility functions
|
31
|
+
|
32
|
+
Examples:
|
33
|
+
Basic usage:
|
34
|
+
|
35
|
+
>>> from brainstate._compatible_import import safe_map, safe_zip
|
36
|
+
>>> result = safe_map(lambda x: x * 2, [1, 2, 3])
|
37
|
+
>>> pairs = safe_zip([1, 2, 3], ['a', 'b', 'c'])
|
38
|
+
|
39
|
+
Using JAX core types:
|
40
|
+
|
41
|
+
>>> from brainstate._compatible_import import Primitive, ClosedJaxpr
|
42
|
+
>>> # These imports work across different JAX versions
|
43
|
+
"""
|
44
|
+
|
45
|
+
from contextlib import contextmanager
|
46
|
+
from functools import partial
|
47
|
+
from typing import Iterable, Hashable, TypeVar, Callable
|
48
|
+
|
49
|
+
import jax
|
50
|
+
from jax.core import get_aval, Tracer
|
51
|
+
from saiunit._compatible_import import wrap_init
|
52
|
+
|
53
|
+
# Names re-exported by this compatibility module (order kept as originally listed).
__all__ = [
    # Core JAX objects and the axis-environment helper
    'ClosedJaxpr', 'Primitive', 'extend_axis_env_nd', 'jaxpr_as_fun',
    'get_aval', 'Tracer', 'to_concrete_aval',
    # Sequence utilities
    'safe_map', 'safe_zip', 'unzip2', 'wraps',
    # Device handle and function wrapping
    'Device', 'wrap_init',
    # Jaxpr building blocks
    'Var', 'JaxprEqn', 'Jaxpr', 'Literal',
    # Batching interpreter internals
    'make_iota', 'to_elt', 'BatchTracer', 'BatchTrace',
]

# Generic type variables used by the annotated helpers in this module.
T = TypeVar("T")
T1 = TypeVar("T1")
T2 = TypeVar("T2")
T3 = TypeVar("T3")
|
79
|
+
|
80
|
+
if jax.__version_info__ < (0, 5, 0):
|
81
|
+
from jax.lib.xla_client import Device
|
82
|
+
else:
|
83
|
+
from jax import Device
|
84
|
+
|
85
|
+
if jax.__version_info__ < (0, 7, 1):
|
86
|
+
from jax.interpreters.batching import make_iota, to_elt, BatchTracer, BatchTrace
|
87
|
+
else:
|
88
|
+
from jax._src.interpreters.batching import make_iota, to_elt, BatchTracer, BatchTrace
|
89
|
+
|
90
|
+
if jax.__version_info__ < (0, 4, 38):
|
91
|
+
from jax.core import ClosedJaxpr, extend_axis_env_nd, Primitive, jaxpr_as_fun
|
92
|
+
from jax.core import Primitive, Var, JaxprEqn, Jaxpr, ClosedJaxpr, Literal
|
93
|
+
else:
|
94
|
+
from jax.extend.core import ClosedJaxpr, Primitive, jaxpr_as_fun
|
95
|
+
from jax.extend.core import Primitive, Var, JaxprEqn, Jaxpr, ClosedJaxpr, Literal
|
96
|
+
from jax.core import trace_ctx
|
97
|
+
|
98
|
+
|
99
|
+
@contextmanager
def extend_axis_env_nd(name_size_pairs: Iterable[tuple[Hashable, int]]):
    """
    Temporarily install an extended JAX axis environment.

    The axis environment currently held by ``trace_ctx`` is replaced with the
    result of ``axis_env.extend_pure(name_size_pairs)`` while the ``with`` body
    runs. The environment captured on entry is reinstalled afterwards, also
    when the body raises.

    Args:
        name_size_pairs: Iterable of ``(name, size)`` tuples describing the
            named axes to add.

    Yields:
        None: Control is handed to the ``with`` body while the extended
        environment is active.

    Examples:
        >>> with extend_axis_env_nd([('batch', 32), ('seq', 128)]):
        ...     # Code that relies on the named axes
        ...     pass
    """
    saved_env = trace_ctx.axis_env
    try:
        extended_env = saved_env.extend_pure(name_size_pairs)
        trace_ctx.set_axis_env(extended_env)
        yield
    finally:
        # Always put back the environment that was active on entry.
        trace_ctx.set_axis_env(saved_env)
|
125
|
+
|
126
|
+
if jax.__version_info__ < (0, 6, 0):
|
127
|
+
from jax.util import safe_map, safe_zip, unzip2, wraps
|
128
|
+
|
129
|
+
else:
|
130
|
+
def safe_map(f, *args):
    """
    Map a function over multiple sequences with length checking.

    Every input is first materialized into a list, the lengths are compared,
    and ``f`` is then applied to the corresponding elements.

    Args:
        f: Function to apply to elements from each sequence.
        *args: One or more sequences (any iterables) to map over.

    Returns:
        list: Results of applying ``f`` to corresponding elements.

    Raises:
        AssertionError: If input sequences have different lengths.
        IndexError: If no sequence is passed at all.

    Examples:
        >>> safe_map(lambda x, y: x + y, [1, 2, 3], [4, 5, 6])
        [5, 7, 9]

        >>> safe_map(str.upper, ['a', 'b', 'c'])
        ['A', 'B', 'C']
    """
    sequences = [list(arg) for arg in args]
    expected = len(sequences[0])
    # Raised explicitly rather than via ``assert`` so the check is still
    # performed under ``python -O``; the exception type is unchanged.
    if any(len(seq) != expected for seq in sequences[1:]):
        raise AssertionError(f'length mismatch: {list(map(len, sequences))}')
    return list(map(f, *sequences))
|
159
|
+
|
160
|
+
|
161
|
+
def safe_zip(*args):
    """
    Zip multiple sequences with length checking.

    Every input is first materialized into a list, the lengths are compared,
    and the corresponding elements are then combined into tuples.

    Args:
        *args: One or more sequences (any iterables) to zip together.

    Returns:
        list: List of tuples containing corresponding elements.

    Raises:
        AssertionError: If input sequences have different lengths.
        IndexError: If no sequence is passed at all.

    Examples:
        >>> safe_zip([1, 2, 3], ['a', 'b', 'c'])
        [(1, 'a'), (2, 'b'), (3, 'c')]

        >>> safe_zip([1, 2], [3, 4], [5, 6])
        [(1, 3, 5), (2, 4, 6)]
    """
    sequences = [list(arg) for arg in args]
    expected = len(sequences[0])
    # Raised explicitly rather than via ``assert`` so the check is still
    # performed under ``python -O``; the exception type is unchanged.
    if any(len(seq) != expected for seq in sequences[1:]):
        raise AssertionError(f'length mismatch: {list(map(len, sequences))}')
    return list(zip(*sequences))
|
189
|
+
|
190
|
+
|
191
|
+
def unzip2(xys: Iterable[tuple[T1, T2]]) -> tuple[tuple[T1, ...], tuple[T2, ...]]:
    """
    Split an iterable of pairs into a tuple of firsts and a tuple of seconds.

    The input is consumed eagerly, one pair at a time. Each item must unpack
    into exactly two values; anything else fails at the unpacking step.

    Args:
        xys: Iterable of 2-tuples to unzip.

    Returns:
        tuple: A 2-tuple ``(firsts, seconds)`` where ``firsts`` holds the
            first element of every pair and ``seconds`` the second, both in
            input order. An empty input yields ``((), ())``.

    Examples:
        >>> pairs = [(1, 'a'), (2, 'b'), (3, 'c')]
        >>> nums, letters = unzip2(pairs)
        >>> nums
        (1, 2, 3)
        >>> letters
        ('a', 'b', 'c')

    Notes:
        ``zip(*xys)`` is deliberately not used: it is lazily evaluated, is
        too permissive about inputs, and does not guarantee a length-2
        output.
    """
    firsts: list[T1] = []
    seconds: list[T2] = []
    for pair in xys:
        # Two-target unpacking is what enforces the "exactly two" contract.
        first, second = pair
        firsts.append(first)
        seconds.append(second)
    return tuple(firsts), tuple(seconds)
|
226
|
+
|
227
|
+
|
228
|
+
def fun_name(fun: Callable):
    """
    Return a display name for a callable.

    The callable's ``__name__`` is used when it has one. A
    ``functools.partial`` without a ``__name__`` is unwrapped through its
    ``func`` attribute until a named callable is found. Anything else yields
    a fixed placeholder.

    Args:
        fun: The function to get the name from.

    Returns:
        str: The function name, or "<unnamed function>" if no name available.

    Examples:
        >>> def my_function():
        ...     pass
        >>> fun_name(my_function)
        'my_function'

        >>> from functools import partial
        >>> add = lambda x, y: x + y
        >>> add_one = partial(add, 1)
        >>> fun_name(add_one)
        '<lambda>'
    """
    target = fun
    while True:
        found = getattr(target, "__name__", None)
        if found is not None:
            return found
        if not isinstance(target, partial):
            return "<unnamed function>"
        # Nameless partial: look at the callable it wraps.
        target = target.func
|
260
|
+
|
261
|
+
|
262
|
+
def wraps(
    wrapped: Callable,
    namestr: str | None = None,
    docstr: str | None = None,
    **kwargs,
) -> Callable[[T], T]:
    """
    Build a decorator that copies metadata from ``wrapped`` onto a function.

    Similar in spirit to ``functools.wraps``, but the resulting name and
    docstring can be customized through format strings.

    The copy is best-effort: the attributes are set in sequence, and if any
    step raises, the remaining steps are skipped and the function is
    returned as-is, without reporting the error.

    Args:
        wrapped: The function being wrapped.
        namestr: Optional format string for the wrapper function name.
            Can use {fun} placeholder for the original function name.
        docstr: Optional format string for the wrapper function docstring.
            Can use {fun}, {doc}, and other kwargs as placeholders.
        **kwargs: Additional keyword arguments for format string substitution.

    Returns:
        Callable: A decorator that updates its argument in place and
            returns that same object.

    Notes:
        ``__qualname__`` is taken from ``wrapped`` whenever it has one, so
        ``namestr`` affects ``__name__`` only in that case.

    Examples:
        >>> def my_decorator(func):
        ...     @wraps(func, namestr="decorated_{fun}")
        ...     def wrapper(*args, **kwargs):
        ...         return func(*args, **kwargs)
        ...     return wrapper

        >>> @my_decorator
        ... def example():
        ...     pass
        >>> example.__name__
        'decorated_example'
    """

    def wrapper(fun: T) -> T:
        # The assignment order below is significant: a failure part-way
        # through leaves the earlier attributes applied and the later ones
        # untouched.
        try:
            base_name = fun_name(wrapped)
            base_doc = getattr(wrapped, "__doc__", "") or ""
            fun.__dict__.update(getattr(wrapped, "__dict__", {}))
            fun.__annotations__ = getattr(wrapped, "__annotations__", {})
            if namestr is None:
                fun.__name__ = base_name
            else:
                fun.__name__ = namestr.format(fun=base_name)
            fun.__module__ = getattr(wrapped, "__module__", "<unknown module>")
            if docstr is None:
                fun.__doc__ = base_doc
            else:
                fun.__doc__ = docstr.format(fun=base_name, doc=base_doc, **kwargs)
            fun.__qualname__ = getattr(wrapped, "__qualname__", fun.__name__)
            fun.__wrapped__ = wrapped
        except Exception:
            # Deliberate best-effort: metadata copying must never prevent
            # the decorated function from being returned.
            pass
        return fun

    return wrapper
|
316
|
+
|
317
|
+
|
318
|
+
def to_concrete_aval(aval):
    """
    Resolve a value through ``get_aval`` and concretize it if it is a Tracer.

    The input is first passed to ``get_aval``. If that result is a
    ``Tracer``, its ``to_concrete_value()`` is returned; otherwise the
    ``get_aval`` result itself is returned.

    Args:
        aval: The value to convert; it must be acceptable to ``get_aval``.

    Returns:
        ``to_concrete_value()`` of the ``get_aval`` result when that result
        is a ``Tracer``, else the ``get_aval`` result unchanged. Note this is
        the output of ``get_aval``, not necessarily the object passed in.

    Examples:
        >>> import jax.numpy as jnp
        >>> arr = jnp.array([1, 2, 3])
        >>> result = to_concrete_aval(arr)  # presumably the array's abstract value -- TODO confirm
    """
    resolved = get_aval(aval)
    return resolved.to_concrete_value() if isinstance(resolved, Tracer) else resolved
|