brainstate 0.2.0__py2.py3-none-any.whl → 0.2.2__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +2 -4
- brainstate/_deprecation_test.py +2 -24
- brainstate/_state.py +540 -35
- brainstate/_state_test.py +1085 -8
- brainstate/graph/_operation.py +1 -5
- brainstate/mixin.py +14 -0
- brainstate/nn/__init__.py +42 -33
- brainstate/nn/_collective_ops.py +2 -0
- brainstate/nn/_common_test.py +0 -20
- brainstate/nn/_delay.py +1 -1
- brainstate/nn/_dropout_test.py +9 -6
- brainstate/nn/_dynamics.py +67 -464
- brainstate/nn/_dynamics_test.py +0 -14
- brainstate/nn/_embedding.py +7 -7
- brainstate/nn/_exp_euler.py +9 -9
- brainstate/nn/_linear.py +21 -21
- brainstate/nn/_module.py +25 -18
- brainstate/nn/_normalizations.py +27 -27
- brainstate/random/__init__.py +6 -6
- brainstate/random/{_rand_funs.py → _fun.py} +1 -1
- brainstate/random/{_rand_funs_test.py → _fun_test.py} +0 -2
- brainstate/random/_impl.py +672 -0
- brainstate/random/{_rand_seed.py → _seed.py} +1 -1
- brainstate/random/{_rand_state.py → _state.py} +121 -418
- brainstate/random/{_rand_state_test.py → _state_test.py} +7 -7
- brainstate/transform/__init__.py +6 -9
- brainstate/transform/_conditions.py +2 -2
- brainstate/transform/_find_state.py +200 -0
- brainstate/transform/_find_state_test.py +84 -0
- brainstate/transform/_make_jaxpr.py +221 -61
- brainstate/transform/_make_jaxpr_test.py +125 -1
- brainstate/transform/_mapping.py +287 -209
- brainstate/transform/_mapping_test.py +94 -184
- {brainstate-0.2.0.dist-info → brainstate-0.2.2.dist-info}/METADATA +1 -1
- {brainstate-0.2.0.dist-info → brainstate-0.2.2.dist-info}/RECORD +39 -39
- brainstate/transform/_eval_shape.py +0 -145
- brainstate/transform/_eval_shape_test.py +0 -38
- brainstate/transform/_random.py +0 -171
- /brainstate/random/{_rand_seed_test.py → _seed_test.py} +0 -0
- {brainstate-0.2.0.dist-info → brainstate-0.2.2.dist-info}/WHEEL +0 -0
- {brainstate-0.2.0.dist-info → brainstate-0.2.2.dist-info}/licenses/LICENSE +0 -0
- {brainstate-0.2.0.dist-info → brainstate-0.2.2.dist-info}/top_level.txt +0 -0
brainstate/__init__.py
CHANGED
@@ -17,8 +17,8 @@
|
|
17
17
|
A ``State``-based Transformation System for Program Compilation and Augmentation
|
18
18
|
"""
|
19
19
|
|
20
|
-
__version__ = "0.2.0"
|
21
|
-
__versio_info__ = (0, 2, 0)
|
20
|
+
__version__ = "0.2.2"
|
21
|
+
__versio_info__ = (0, 2, 2)
|
22
22
|
|
23
23
|
from . import environ
|
24
24
|
from . import graph
|
@@ -45,12 +45,10 @@ _augment_apis = {
|
|
45
45
|
'jacobian': 'brainstate.transform._autograd',
|
46
46
|
'jacrev': 'brainstate.transform._autograd',
|
47
47
|
'jacfwd': 'brainstate.transform._autograd',
|
48
|
-
'abstract_init': 'brainstate.transform._eval_shape',
|
49
48
|
'vmap': 'brainstate.transform._mapping',
|
50
49
|
'pmap': 'brainstate.transform._mapping',
|
51
50
|
'map': 'brainstate.transform._mapping',
|
52
51
|
'vmap_new_states': 'brainstate.transform._mapping',
|
53
|
-
'restore_rngs': 'brainstate.transform._random',
|
54
52
|
}
|
55
53
|
|
56
54
|
augment = create_deprecated_module_proxy(
|
brainstate/_deprecation_test.py
CHANGED
@@ -51,8 +51,8 @@ class TestDeprecatedAugmentModule(unittest.TestCase):
|
|
51
51
|
# Check that expected APIs are available
|
52
52
|
expected_apis = [
|
53
53
|
'GradientTransform', 'grad', 'vector_grad', 'hessian', 'jacobian',
|
54
|
-
'jacrev', 'jacfwd', '
|
55
|
-
'vmap_new_states',
|
54
|
+
'jacrev', 'jacfwd', 'vmap', 'pmap', 'map',
|
55
|
+
'vmap_new_states',
|
56
56
|
]
|
57
57
|
|
58
58
|
for api in expected_apis:
|
@@ -1379,12 +1379,10 @@ class TestDeprecatedAugment(unittest.TestCase):
|
|
1379
1379
|
'jacobian',
|
1380
1380
|
'jacrev',
|
1381
1381
|
'jacfwd',
|
1382
|
-
'abstract_init',
|
1383
1382
|
'vmap',
|
1384
1383
|
'pmap',
|
1385
1384
|
'map',
|
1386
1385
|
'vmap_new_states',
|
1387
|
-
'restore_rngs',
|
1388
1386
|
]
|
1389
1387
|
|
1390
1388
|
for func_name in augment_funcs:
|
@@ -1503,26 +1501,6 @@ class TestDeprecatedAugment(unittest.TestCase):
|
|
1503
1501
|
vmap_new_states = brainstate.augment.vmap_new_states
|
1504
1502
|
self.assertIsNotNone(vmap_new_states)
|
1505
1503
|
|
1506
|
-
def test_abstract_init(self):
|
1507
|
-
"""Test abstract_init function."""
|
1508
|
-
with warnings.catch_warnings(record=True):
|
1509
|
-
warnings.simplefilter("always")
|
1510
|
-
import brainstate
|
1511
|
-
|
1512
|
-
# Test abstract_init
|
1513
|
-
abstract_init = brainstate.augment.abstract_init
|
1514
|
-
self.assertIsNotNone(abstract_init)
|
1515
|
-
|
1516
|
-
def test_restore_rngs(self):
|
1517
|
-
"""Test restore_rngs function."""
|
1518
|
-
with warnings.catch_warnings(record=True):
|
1519
|
-
warnings.simplefilter("always")
|
1520
|
-
import brainstate
|
1521
|
-
|
1522
|
-
# Test restore_rngs
|
1523
|
-
restore_rngs = brainstate.augment.restore_rngs
|
1524
|
-
self.assertIsNotNone(restore_rngs)
|
1525
|
-
|
1526
1504
|
def test_module_attributes(self):
|
1527
1505
|
"""Test module-level attributes."""
|
1528
1506
|
with warnings.catch_warnings(record=True):
|