brainstate 0.2.0__py2.py3-none-any.whl → 0.2.2__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. brainstate/__init__.py +2 -4
  2. brainstate/_deprecation_test.py +2 -24
  3. brainstate/_state.py +540 -35
  4. brainstate/_state_test.py +1085 -8
  5. brainstate/graph/_operation.py +1 -5
  6. brainstate/mixin.py +14 -0
  7. brainstate/nn/__init__.py +42 -33
  8. brainstate/nn/_collective_ops.py +2 -0
  9. brainstate/nn/_common_test.py +0 -20
  10. brainstate/nn/_delay.py +1 -1
  11. brainstate/nn/_dropout_test.py +9 -6
  12. brainstate/nn/_dynamics.py +67 -464
  13. brainstate/nn/_dynamics_test.py +0 -14
  14. brainstate/nn/_embedding.py +7 -7
  15. brainstate/nn/_exp_euler.py +9 -9
  16. brainstate/nn/_linear.py +21 -21
  17. brainstate/nn/_module.py +25 -18
  18. brainstate/nn/_normalizations.py +27 -27
  19. brainstate/random/__init__.py +6 -6
  20. brainstate/random/{_rand_funs.py → _fun.py} +1 -1
  21. brainstate/random/{_rand_funs_test.py → _fun_test.py} +0 -2
  22. brainstate/random/_impl.py +672 -0
  23. brainstate/random/{_rand_seed.py → _seed.py} +1 -1
  24. brainstate/random/{_rand_state.py → _state.py} +121 -418
  25. brainstate/random/{_rand_state_test.py → _state_test.py} +7 -7
  26. brainstate/transform/__init__.py +6 -9
  27. brainstate/transform/_conditions.py +2 -2
  28. brainstate/transform/_find_state.py +200 -0
  29. brainstate/transform/_find_state_test.py +84 -0
  30. brainstate/transform/_make_jaxpr.py +221 -61
  31. brainstate/transform/_make_jaxpr_test.py +125 -1
  32. brainstate/transform/_mapping.py +287 -209
  33. brainstate/transform/_mapping_test.py +94 -184
  34. {brainstate-0.2.0.dist-info → brainstate-0.2.2.dist-info}/METADATA +1 -1
  35. {brainstate-0.2.0.dist-info → brainstate-0.2.2.dist-info}/RECORD +39 -39
  36. brainstate/transform/_eval_shape.py +0 -145
  37. brainstate/transform/_eval_shape_test.py +0 -38
  38. brainstate/transform/_random.py +0 -171
  39. /brainstate/random/{_rand_seed_test.py → _seed_test.py} +0 -0
  40. {brainstate-0.2.0.dist-info → brainstate-0.2.2.dist-info}/WHEEL +0 -0
  41. {brainstate-0.2.0.dist-info → brainstate-0.2.2.dist-info}/licenses/LICENSE +0 -0
  42. {brainstate-0.2.0.dist-info → brainstate-0.2.2.dist-info}/top_level.txt +0 -0
brainstate/__init__.py CHANGED
@@ -17,8 +17,8 @@
17
17
  A ``State``-based Transformation System for Program Compilation and Augmentation
18
18
  """
19
19
 
20
- __version__ = "0.2.0"
21
- __versio_info__ = (0, 2, 0)
20
+ __version__ = "0.2.2"
21
+ __versio_info__ = (0, 2, 2)
22
22
 
23
23
  from . import environ
24
24
  from . import graph
@@ -45,12 +45,10 @@ _augment_apis = {
45
45
  'jacobian': 'brainstate.transform._autograd',
46
46
  'jacrev': 'brainstate.transform._autograd',
47
47
  'jacfwd': 'brainstate.transform._autograd',
48
- 'abstract_init': 'brainstate.transform._eval_shape',
49
48
  'vmap': 'brainstate.transform._mapping',
50
49
  'pmap': 'brainstate.transform._mapping',
51
50
  'map': 'brainstate.transform._mapping',
52
51
  'vmap_new_states': 'brainstate.transform._mapping',
53
- 'restore_rngs': 'brainstate.transform._random',
54
52
  }
55
53
 
56
54
  augment = create_deprecated_module_proxy(
@@ -51,8 +51,8 @@ class TestDeprecatedAugmentModule(unittest.TestCase):
51
51
  # Check that expected APIs are available
52
52
  expected_apis = [
53
53
  'GradientTransform', 'grad', 'vector_grad', 'hessian', 'jacobian',
54
- 'jacrev', 'jacfwd', 'abstract_init', 'vmap', 'pmap', 'map',
55
- 'vmap_new_states', 'restore_rngs'
54
+ 'jacrev', 'jacfwd', 'vmap', 'pmap', 'map',
55
+ 'vmap_new_states',
56
56
  ]
57
57
 
58
58
  for api in expected_apis:
@@ -1379,12 +1379,10 @@ class TestDeprecatedAugment(unittest.TestCase):
1379
1379
  'jacobian',
1380
1380
  'jacrev',
1381
1381
  'jacfwd',
1382
- 'abstract_init',
1383
1382
  'vmap',
1384
1383
  'pmap',
1385
1384
  'map',
1386
1385
  'vmap_new_states',
1387
- 'restore_rngs',
1388
1386
  ]
1389
1387
 
1390
1388
  for func_name in augment_funcs:
@@ -1503,26 +1501,6 @@ class TestDeprecatedAugment(unittest.TestCase):
1503
1501
  vmap_new_states = brainstate.augment.vmap_new_states
1504
1502
  self.assertIsNotNone(vmap_new_states)
1505
1503
 
1506
- def test_abstract_init(self):
1507
- """Test abstract_init function."""
1508
- with warnings.catch_warnings(record=True):
1509
- warnings.simplefilter("always")
1510
- import brainstate
1511
-
1512
- # Test abstract_init
1513
- abstract_init = brainstate.augment.abstract_init
1514
- self.assertIsNotNone(abstract_init)
1515
-
1516
- def test_restore_rngs(self):
1517
- """Test restore_rngs function."""
1518
- with warnings.catch_warnings(record=True):
1519
- warnings.simplefilter("always")
1520
- import brainstate
1521
-
1522
- # Test restore_rngs
1523
- restore_rngs = brainstate.augment.restore_rngs
1524
- self.assertIsNotNone(restore_rngs)
1525
-
1526
1504
  def test_module_attributes(self):
1527
1505
  """Test module-level attributes."""
1528
1506
  with warnings.catch_warnings(record=True):