brainstate 0.2.1__py2.py3-none-any.whl → 0.2.2__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (115)
  1. brainstate/__init__.py +167 -169
  2. brainstate/_compatible_import.py +340 -340
  3. brainstate/_compatible_import_test.py +681 -681
  4. brainstate/_deprecation.py +210 -210
  5. brainstate/_deprecation_test.py +2297 -2319
  6. brainstate/_error.py +45 -45
  7. brainstate/_state.py +2157 -1652
  8. brainstate/_state_test.py +1129 -52
  9. brainstate/_utils.py +47 -47
  10. brainstate/environ.py +1495 -1495
  11. brainstate/environ_test.py +1223 -1223
  12. brainstate/graph/__init__.py +22 -22
  13. brainstate/graph/_node.py +240 -240
  14. brainstate/graph/_node_test.py +589 -589
  15. brainstate/graph/_operation.py +1620 -1624
  16. brainstate/graph/_operation_test.py +1147 -1147
  17. brainstate/mixin.py +1447 -1433
  18. brainstate/mixin_test.py +1017 -1017
  19. brainstate/nn/__init__.py +146 -137
  20. brainstate/nn/_activations.py +1100 -1100
  21. brainstate/nn/_activations_test.py +354 -354
  22. brainstate/nn/_collective_ops.py +635 -633
  23. brainstate/nn/_collective_ops_test.py +774 -774
  24. brainstate/nn/_common.py +226 -226
  25. brainstate/nn/_common_test.py +134 -154
  26. brainstate/nn/_conv.py +2010 -2010
  27. brainstate/nn/_conv_test.py +849 -849
  28. brainstate/nn/_delay.py +575 -575
  29. brainstate/nn/_delay_test.py +243 -243
  30. brainstate/nn/_dropout.py +618 -618
  31. brainstate/nn/_dropout_test.py +480 -477
  32. brainstate/nn/_dynamics.py +870 -1267
  33. brainstate/nn/_dynamics_test.py +53 -67
  34. brainstate/nn/_elementwise.py +1298 -1298
  35. brainstate/nn/_elementwise_test.py +829 -829
  36. brainstate/nn/_embedding.py +408 -408
  37. brainstate/nn/_embedding_test.py +156 -156
  38. brainstate/nn/_event_fixedprob.py +233 -233
  39. brainstate/nn/_event_fixedprob_test.py +115 -115
  40. brainstate/nn/_event_linear.py +83 -83
  41. brainstate/nn/_event_linear_test.py +121 -121
  42. brainstate/nn/_exp_euler.py +254 -254
  43. brainstate/nn/_exp_euler_test.py +377 -377
  44. brainstate/nn/_linear.py +744 -744
  45. brainstate/nn/_linear_test.py +475 -475
  46. brainstate/nn/_metrics.py +1070 -1070
  47. brainstate/nn/_metrics_test.py +611 -611
  48. brainstate/nn/_module.py +391 -384
  49. brainstate/nn/_module_test.py +40 -40
  50. brainstate/nn/_normalizations.py +1334 -1334
  51. brainstate/nn/_normalizations_test.py +699 -699
  52. brainstate/nn/_paddings.py +1020 -1020
  53. brainstate/nn/_paddings_test.py +722 -722
  54. brainstate/nn/_poolings.py +2239 -2239
  55. brainstate/nn/_poolings_test.py +952 -952
  56. brainstate/nn/_rnns.py +946 -946
  57. brainstate/nn/_rnns_test.py +592 -592
  58. brainstate/nn/_utils.py +216 -216
  59. brainstate/nn/_utils_test.py +401 -401
  60. brainstate/nn/init.py +809 -809
  61. brainstate/nn/init_test.py +180 -180
  62. brainstate/random/__init__.py +270 -270
  63. brainstate/random/{_rand_funs.py → _fun.py} +3938 -3938
  64. brainstate/random/{_rand_funs_test.py → _fun_test.py} +638 -640
  65. brainstate/random/_impl.py +672 -0
  66. brainstate/random/{_rand_seed.py → _seed.py} +675 -675
  67. brainstate/random/{_rand_seed_test.py → _seed_test.py} +48 -48
  68. brainstate/random/{_rand_state.py → _state.py} +1320 -1617
  69. brainstate/random/{_rand_state_test.py → _state_test.py} +551 -551
  70. brainstate/transform/__init__.py +56 -59
  71. brainstate/transform/_ad_checkpoint.py +176 -176
  72. brainstate/transform/_ad_checkpoint_test.py +49 -49
  73. brainstate/transform/_autograd.py +1025 -1025
  74. brainstate/transform/_autograd_test.py +1289 -1289
  75. brainstate/transform/_conditions.py +316 -316
  76. brainstate/transform/_conditions_test.py +220 -220
  77. brainstate/transform/_error_if.py +94 -94
  78. brainstate/transform/_error_if_test.py +52 -52
  79. brainstate/transform/_find_state.py +200 -0
  80. brainstate/transform/_find_state_test.py +84 -0
  81. brainstate/transform/_jit.py +399 -399
  82. brainstate/transform/_jit_test.py +143 -143
  83. brainstate/transform/_loop_collect_return.py +675 -675
  84. brainstate/transform/_loop_collect_return_test.py +58 -58
  85. brainstate/transform/_loop_no_collection.py +283 -283
  86. brainstate/transform/_loop_no_collection_test.py +50 -50
  87. brainstate/transform/_make_jaxpr.py +2176 -2016
  88. brainstate/transform/_make_jaxpr_test.py +1634 -1510
  89. brainstate/transform/_mapping.py +607 -529
  90. brainstate/transform/_mapping_test.py +104 -194
  91. brainstate/transform/_progress_bar.py +255 -255
  92. brainstate/transform/_unvmap.py +256 -256
  93. brainstate/transform/_util.py +286 -286
  94. brainstate/typing.py +837 -837
  95. brainstate/typing_test.py +780 -780
  96. brainstate/util/__init__.py +27 -27
  97. brainstate/util/_others.py +1024 -1024
  98. brainstate/util/_others_test.py +962 -962
  99. brainstate/util/_pretty_pytree.py +1301 -1301
  100. brainstate/util/_pretty_pytree_test.py +675 -675
  101. brainstate/util/_pretty_repr.py +462 -462
  102. brainstate/util/_pretty_repr_test.py +696 -696
  103. brainstate/util/filter.py +945 -945
  104. brainstate/util/filter_test.py +911 -911
  105. brainstate/util/struct.py +910 -910
  106. brainstate/util/struct_test.py +602 -602
  107. {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/METADATA +108 -108
  108. brainstate-0.2.2.dist-info/RECORD +111 -0
  109. {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/licenses/LICENSE +202 -202
  110. brainstate/transform/_eval_shape.py +0 -145
  111. brainstate/transform/_eval_shape_test.py +0 -38
  112. brainstate/transform/_random.py +0 -171
  113. brainstate-0.2.1.dist-info/RECORD +0 -111
  114. {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/WHEEL +0 -0
  115. {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/top_level.txt +0 -0
@@ -1,194 +1,104 @@
1
- # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
-
17
- import unittest
18
-
19
- import jax
20
- import jax.numpy as jnp
21
- from jax import vmap
22
- from jax.lax import psum, pmean, pmax
23
-
24
- import brainstate
25
- import brainstate.transform
26
- from brainstate._error import BatchAxisError
27
-
28
-
29
-
30
class TestMap(unittest.TestCase):
    """Checks for ``brainstate.transform.map`` against a plain ``x + 1``."""

    def test_map(self):
        """Every batch size (including ``None``) must reproduce ``x + 1``."""
        shapes = [(10,), (10, 10), (10, 10, 10)]
        batch_sizes = (None, 2, 4, 5)
        for shape in shapes:
            data = brainstate.random.rand(*shape)
            reference = data + 1
            # Batch sizes 4 and 10-not-divisible cases exercise the remainder path.
            outputs = [
                brainstate.transform.map(lambda a: a + 1, data, batch_size=bs)
                for bs in batch_sizes
            ]
            for out in outputs:
                self.assertTrue(jnp.allclose(out, reference))
44
-
45
-
46
class TestAxisName(unittest.TestCase):
    """Check that ``brainstate.transform.vmap`` honours ``axis_name`` like ``jax.vmap``.

    Inherits ``unittest.TestCase`` (like ``TestMap``) so the standard unittest
    runner discovers it, not only pytest.
    """

    def test1(self):
        def compute_stats_with_axis_name(x):
            """Compute sum/mean/max across the named axis ``'batch'``."""
            return {
                'sum': psum(x, axis_name='batch'),
                'mean': pmean(x, axis_name='batch'),
                'max': pmax(x, axis_name='batch'),
                'original': x,
            }

        batch_data = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])

        # Reference result: plain (jitted) JAX vmap with axis name 'batch'.
        result_jax = jax.jit(vmap(compute_stats_with_axis_name, axis_name='batch'))(batch_data)

        # brainstate vmap, both eagerly and wrapped in brainstate's jit.
        result = brainstate.transform.vmap(compute_stats_with_axis_name, axis_name='batch')(batch_data)
        result_v2 = brainstate.transform.jit(
            brainstate.transform.vmap(compute_stats_with_axis_name, axis_name='batch')
        )(batch_data)

        for key, expected in result_jax.items():
            self.assertTrue(jnp.allclose(expected, result[key]), f"Mismatch in {key}")
            self.assertTrue(jnp.allclose(expected, result_v2[key]), f"Mismatch in {key}")

    def test_nested_vmap(self):
        def nested_computation(x):
            """Sum over the 'inner' axis, then average over the 'outer' axis."""
            inner_sum = psum(x, axis_name='inner')
            return pmean(inner_sum, axis_name='outer')

        # Shape: [outer_batch=3, inner_batch=4]
        data_2d = jnp.arange(12.0).reshape(3, 4)

        # Nested vmap: inner dimension first, then the outer dimension.
        nested_vmap = vmap(vmap(nested_computation, axis_name='inner'), axis_name='outer')
        result_2d = nested_vmap(data_2d)

        inner_vmap_bst = brainstate.transform.vmap(nested_computation, axis_name='inner')
        nested_vmap_bst = brainstate.transform.vmap(inner_vmap_bst, axis_name='outer')
        result_2d_bst = nested_vmap_bst(data_2d)

        self.assertTrue(jnp.allclose(result_2d, result_2d_bst))

    @staticmethod
    def _gradient_averaging_simulation(vmap_fn):
        """Compute per-sample and batch-averaged gradients using ``vmap_fn``.

        Args:
            vmap_fn: a vmap implementation accepting ``in_axes`` and ``axis_name``
                (``jax.vmap`` or ``brainstate.transform.vmap``).

        Returns:
            ``(individual_grads, averaged_grads)``; the second is the ``pmean`` of
            the first over the named axis ``'batch'``, broadcast back per sample.
        """

        def loss_function(params, x, y):
            """Simple quadratic loss."""
            pred = params * x
            return (pred - y) ** 2

        def average_grads(grads):
            return pmean(grads, axis_name='batch')

        grad_fn = jax.grad(loss_function, argnums=0)

        params = 2.0
        batch_x = jnp.array([1.0, 2.0, 3.0, 4.0])
        batch_y = jnp.array([2.0, 4.0, 7.0, 8.0])

        # Per-sample gradients: params is shared (None), data is batched on axis 0.
        individual_grads = vmap_fn(grad_fn, in_axes=(None, 0, 0))(params, batch_x, batch_y)
        # Average them across the batch using the named axis.
        averaged_grads = vmap_fn(average_grads, axis_name='batch')(individual_grads)
        return individual_grads, averaged_grads

    def _gradient_averaging_simulation_bst(self):
        """Gradient-averaging simulation using ``brainstate.transform.vmap``."""
        return self._gradient_averaging_simulation(brainstate.transform.vmap)

    def _gradient_averaging_simulation_jax(self):
        """Gradient-averaging simulation using ``jax.vmap``."""
        return self._gradient_averaging_simulation(vmap)

    def test_gradient_averaging(self):
        ind_jax, avg_jax = self._gradient_averaging_simulation_jax()
        ind_bst, avg_bst = self._gradient_averaging_simulation_bst()

        # d/dp (p*x - y)^2 = 2*x*(p*x - y); with p=2 only the third sample has a
        # non-zero residual (6 - 7 = -1), giving 2*3*(-1) = -6.
        expected_individual = jnp.array([0.0, 0.0, -6.0, 0.0])
        expected_average = jnp.full((4,), -1.5)

        self.assertTrue(jnp.allclose(ind_jax, expected_individual))
        self.assertTrue(jnp.allclose(avg_jax, expected_average))
        self.assertTrue(jnp.allclose(ind_bst, ind_jax))
        self.assertTrue(jnp.allclose(avg_bst, avg_jax))
194
-
1
+ import unittest
2
+
3
+ import jax
4
+ import jax.numpy as jnp
5
+
6
+ import brainstate as bst
7
+ from brainstate.transform import StatefulMapping, vmap, vmap_new_states, pmap, map as bst_map
8
+ from brainstate.util import filter as state_filter
9
+
10
+
11
class TestMap(unittest.TestCase):
    """``brainstate.transform.map`` must reproduce ``jax.vmap`` results."""

    def _assert_same(self, result, expected):
        # jnp.allclose broadcasts, so a wrongly-shaped result (e.g. (6,) instead
        # of (6, 1)) could still compare "close"; pin the shape first.
        self.assertEqual(result.shape, expected.shape)
        self.assertTrue(jnp.allclose(result, expected))

    def test_map_matches_vectorized(self):
        xs = jnp.arange(6.0).reshape(6, 1)

        def fn(x):
            return x + 1.0

        expected = jax.vmap(fn)(xs)
        result = bst_map(fn, xs)
        self._assert_same(result, expected)

    def test_map_multiple_inputs_and_batch_size(self):
        # 5 elements with batch_size=2 leaves a remainder batch of size 1.
        xs = jnp.arange(5.0)
        ys = jnp.ones_like(xs) * 2.0

        def fn(a, b):
            return a * a + b

        expected = jax.vmap(fn)(xs, ys)
        result = bst_map(fn, xs, ys, batch_size=2)
        self._assert_same(result, expected)
32
+
33
+
34
class TestVmapIntegration(unittest.TestCase):
    """Integration checks for ``brainstate.transform.vmap`` with and without states."""

    def test_decorator_batched_stateful_function(self):
        # One counter entry per batch element: the state is mapped along axis 0.
        counter = bst.ShortTermState(jnp.zeros(3))

        @vmap(
            in_axes=0,
            out_axes=0,
            state_in_axes={0: state_filter.OfType(bst.ShortTermState)},
            state_out_axes={0: state_filter.OfType(bst.ShortTermState)},
        )
        def accumulate(x):
            counter.value = counter.value + x
            return counter.value

        xs = jnp.asarray([1.0, 2.0, 3.0])
        result = accumulate(xs)

        # jnp.allclose broadcasts, so check shapes explicitly before values.
        self.assertEqual(result.shape, xs.shape)
        self.assertEqual(counter.value.shape, xs.shape)
        # Starting from zeros, each element's counter equals its own input.
        self.assertTrue(jnp.allclose(result, xs))
        self.assertTrue(jnp.allclose(counter.value, xs))

    def test_vmap_partial_returns_stateful_mapping(self):
        # Calling vmap without a function returns a builder (decorator factory).
        builder = vmap(in_axes=0, out_axes=0)

        def fn(x):
            return x * 2.0

        mapped = builder(fn)
        self.assertIsInstance(mapped, StatefulMapping)

        xs = jnp.arange(3.0)
        out = mapped(xs)
        self.assertEqual(out.shape, xs.shape)
        self.assertTrue(jnp.allclose(out, xs * 2.0))
63
+
64
+
65
class TestVmapNewStates(unittest.TestCase):
    """Checks for ``vmap_new_states``: states created inside the call are vectorized."""

    def test_new_states_are_vectorized(self):
        def build(x):
            scratch = bst.ShortTermState(jnp.array(0.0), tag='scratch')
            scratch.value = scratch.value + x
            return scratch.value

        build = vmap_new_states(in_axes=0, out_axes=0)(build)

        xs = jnp.arange(4.0)
        # Both calls must return ``xs``: the scratch state starts at zero on each
        # call rather than carrying over the previous call's value.
        outcomes = (build(xs), build(xs))
        for outcome in outcomes:
            self.assertTrue(jnp.allclose(outcome, xs))
78
+
79
+
80
class TestPmapIntegration(unittest.TestCase):
    """Integration check for ``brainstate.transform.pmap`` with a mapped ParamState."""

    @unittest.skipIf(jax.local_device_count() < 2, "Requires at least 2 devices")
    def test_pmap_stateful_execution(self):
        device_count = jax.local_device_count()

        # The state is mapped along axis 0 (state_in_axes={0: ...}), so its
        # leading dimension must match the mapped size of ``deltas`` (one row
        # per device). A plain (4,)-shaped state only lines up when exactly 4
        # devices are present, which made this test fail on 2/8-device hosts.
        initial = jnp.ones((device_count, 4))
        param = bst.ParamState(initial)

        @pmap(
            in_axes=0,
            out_axes=0,
            axis_name='devices',
            state_in_axes={0: state_filter.OfType(bst.ParamState)},
            state_out_axes={0: state_filter.OfType(bst.ParamState)},
        )
        def update(delta):
            param.value = param.value + delta
            return param.value

        deltas = jnp.arange(device_count * 4.0, dtype=initial.dtype).reshape(device_count, 4)
        updated = update(deltas)

        expected = initial + deltas
        self.assertEqual(updated.shape, (device_count, 4))
        self.assertTrue(jnp.allclose(updated, expected))
        # The per-device updates must be written back to the state.
        self.assertTrue(jnp.allclose(param.value, expected))
101
+
102
+
103
if __name__ == "__main__":
    # Allow running this test module directly with the standard unittest runner.
    unittest.main()