brainstate 0.1.1__py2.py3-none-any.whl → 0.1.2__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. brainstate/__init__.py +1 -1
  2. brainstate/_compatible_import.py +3 -0
  3. brainstate/_state.py +1 -1
  4. brainstate/augment/_autograd_test.py +132 -133
  5. brainstate/augment/_eval_shape_test.py +7 -9
  6. brainstate/augment/_mapping_test.py +75 -76
  7. brainstate/compile/_ad_checkpoint_test.py +6 -8
  8. brainstate/compile/_conditions_test.py +35 -36
  9. brainstate/compile/_error_if_test.py +10 -13
  10. brainstate/compile/_loop_collect_return_test.py +7 -9
  11. brainstate/compile/_loop_no_collection_test.py +7 -8
  12. brainstate/compile/_make_jaxpr.py +29 -14
  13. brainstate/compile/_make_jaxpr_test.py +20 -20
  14. brainstate/functional/_activations_test.py +61 -61
  15. brainstate/graph/_graph_node_test.py +16 -18
  16. brainstate/graph/_graph_operation_test.py +154 -156
  17. brainstate/init/_random_inits_test.py +20 -21
  18. brainstate/init/_regular_inits_test.py +4 -5
  19. brainstate/nn/_collective_ops_test.py +8 -8
  20. brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +18 -19
  21. brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +9 -10
  22. brainstate/nn/_dyn_impl/_rate_rnns_test.py +6 -7
  23. brainstate/nn/_dyn_impl/_readout_test.py +9 -10
  24. brainstate/nn/_dynamics/_dynamics_base_test.py +14 -15
  25. brainstate/nn/_dynamics/_synouts_test.py +4 -5
  26. brainstate/nn/_elementwise/_dropout_test.py +9 -9
  27. brainstate/nn/_elementwise/_elementwise_test.py +57 -59
  28. brainstate/nn/_event/_fixedprob_mv_test.py +0 -1
  29. brainstate/nn/_event/_linear_mv_test.py +0 -1
  30. brainstate/nn/_exp_euler_test.py +5 -6
  31. brainstate/nn/_interaction/_conv_test.py +31 -33
  32. brainstate/nn/_interaction/_linear_test.py +15 -17
  33. brainstate/nn/_interaction/_normalizations_test.py +10 -12
  34. brainstate/nn/_interaction/_poolings_test.py +19 -21
  35. brainstate/nn/_module_test.py +34 -37
  36. brainstate/optim/_lr_scheduler_test.py +3 -3
  37. brainstate/optim/_optax_optimizer_test.py +8 -9
  38. brainstate/random/_rand_funs_test.py +183 -184
  39. brainstate/random/_rand_seed_test.py +10 -12
  40. {brainstate-0.1.1.dist-info → brainstate-0.1.2.dist-info}/METADATA +1 -1
  41. {brainstate-0.1.1.dist-info → brainstate-0.1.2.dist-info}/RECORD +44 -44
  42. {brainstate-0.1.1.dist-info → brainstate-0.1.2.dist-info}/LICENSE +0 -0
  43. {brainstate-0.1.1.dist-info → brainstate-0.1.2.dist-info}/WHEEL +0 -0
  44. {brainstate-0.1.1.dist-info → brainstate-0.1.2.dist-info}/top_level.txt +0 -0
@@ -14,27 +14,25 @@
14
14
  # ==============================================================================
15
15
 
16
16
 
17
- from __future__ import annotations
18
-
19
17
  import unittest
20
18
 
21
- import brainstate as bst
19
+ import brainstate
22
20
 
23
21
 
24
22
  class TestEvalShape(unittest.TestCase):
25
23
  def test1(self):
26
- class MLP(bst.nn.Module):
24
+ class MLP(brainstate.nn.Module):
27
25
  def __init__(self, n_in, n_mid, n_out):
28
26
  super().__init__()
29
- self.dense1 = bst.nn.Linear(n_in, n_mid)
30
- self.dense2 = bst.nn.Linear(n_mid, n_out)
27
+ self.dense1 = brainstate.nn.Linear(n_in, n_mid)
28
+ self.dense2 = brainstate.nn.Linear(n_mid, n_out)
31
29
 
32
30
  def __call__(self, x):
33
31
  x = self.dense1(x)
34
- x = bst.functional.relu(x)
32
+ x = brainstate.functional.relu(x)
35
33
  x = self.dense2(x)
36
34
  return x
37
35
 
38
- r = bst.augment.abstract_init(lambda: MLP(1, 2, 3))
36
+ r = brainstate.augment.abstract_init(lambda: MLP(1, 2, 3))
39
37
  print(r)
40
- print(bst.random.DEFAULT)
38
+ print(brainstate.random.DEFAULT)
@@ -13,7 +13,6 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- from __future__ import annotations
17
16
 
18
17
  import unittest
19
18
 
@@ -21,7 +20,7 @@ import jax
21
20
  import jax.numpy as jnp
22
21
  import numpy as np
23
22
 
24
- import brainstate as bst
23
+ import brainstate
25
24
  import brainstate.augment
26
25
  from brainstate.augment._mapping import BatchAxisError
27
26
  from brainstate.augment._mapping import _remove_axis
@@ -29,29 +28,29 @@ from brainstate.augment._mapping import _remove_axis
29
28
 
30
29
  class TestVmap(unittest.TestCase):
31
30
  def test_vmap_1(self):
32
- class Model(bst.nn.Module):
31
+ class Model(brainstate.nn.Module):
33
32
  def __init__(self):
34
33
  super().__init__()
35
34
 
36
- self.a = bst.State(bst.random.randn(5))
37
- self.b = bst.State(bst.random.randn(5))
35
+ self.a = brainstate.State(brainstate.random.randn(5))
36
+ self.b = brainstate.State(brainstate.random.randn(5))
38
37
 
39
38
  def __call__(self, *args, **kwargs):
40
39
  return self.a.value * self.b.value
41
40
 
42
41
  model = Model()
43
42
  r1 = model.a.value * model.b.value
44
- r2 = bst.augment.vmap(model, in_states=model.states())()
43
+ r2 = brainstate.augment.vmap(model, in_states=model.states())()
45
44
  self.assertTrue(jnp.allclose(r1, r2))
46
45
 
47
46
  def test_vmap_2(self):
48
- class Model(bst.nn.Module):
47
+ class Model(brainstate.nn.Module):
49
48
  def __init__(self):
50
49
  super().__init__()
51
50
 
52
- self.a = bst.ShortTermState(bst.random.randn(5))
53
- self.b = bst.ShortTermState(bst.random.randn(5))
54
- self.c = bst.State(bst.random.randn(1))
51
+ self.a = brainstate.ShortTermState(brainstate.random.randn(5))
52
+ self.b = brainstate.ShortTermState(brainstate.random.randn(5))
53
+ self.c = brainstate.State(brainstate.random.randn(1))
55
54
 
56
55
  def __call__(self, *args, **kwargs):
57
56
  self.c.value = self.a.value * self.b.value
@@ -59,104 +58,104 @@ class TestVmap(unittest.TestCase):
59
58
 
60
59
  model = Model()
61
60
  with self.assertRaises(BatchAxisError):
62
- r2 = bst.augment.vmap(model, in_states=model.states(bst.ShortTermState))()
61
+ r2 = brainstate.augment.vmap(model, in_states=model.states(brainstate.ShortTermState))()
63
62
 
64
63
  model = Model()
65
- r2 = bst.augment.vmap(model, in_states=model.states(bst.ShortTermState), out_states=model.c)()
64
+ r2 = brainstate.augment.vmap(model, in_states=model.states(brainstate.ShortTermState), out_states=model.c)()
66
65
 
67
66
  def test_vmap_3(self):
68
- class Model(bst.nn.Module):
67
+ class Model(brainstate.nn.Module):
69
68
  def __init__(self):
70
69
  super().__init__()
71
70
 
72
- self.a = bst.State(bst.random.randn(5))
73
- self.b = bst.State(bst.random.randn(5))
71
+ self.a = brainstate.State(brainstate.random.randn(5))
72
+ self.b = brainstate.State(brainstate.random.randn(5))
74
73
 
75
74
  def __call__(self, *args, **kwargs):
76
75
  return self.a.value * self.b.value
77
76
 
78
77
  model = Model()
79
78
  with self.assertRaises(BatchAxisError):
80
- r2 = bst.augment.vmap(model, in_states=model.states(), out_states={1: model.states()})()
79
+ r2 = brainstate.augment.vmap(model, in_states=model.states(), out_states={1: model.states()})()
81
80
 
82
81
  def test_vmap_with_random(self):
83
- class Model(bst.nn.Module):
82
+ class Model(brainstate.nn.Module):
84
83
  def __init__(self):
85
84
  super().__init__()
86
85
 
87
- self.a = bst.ShortTermState(bst.random.randn(5))
88
- self.b = bst.ShortTermState(bst.random.randn(5))
89
- self.c = bst.State(bst.random.randn(1))
86
+ self.a = brainstate.ShortTermState(brainstate.random.randn(5))
87
+ self.b = brainstate.ShortTermState(brainstate.random.randn(5))
88
+ self.c = brainstate.State(brainstate.random.randn(1))
90
89
 
91
90
  def __call__(self, key):
92
- bst.random.set_key(key)
91
+ brainstate.random.set_key(key)
93
92
  self.c.value = self.a.value * self.b.value
94
- return self.c.value + bst.random.randn(1)
93
+ return self.c.value + brainstate.random.randn(1)
95
94
 
96
95
  model = Model()
97
- r2 = bst.augment.vmap(
96
+ r2 = brainstate.augment.vmap(
98
97
  model,
99
- in_states=model.states(bst.ShortTermState),
98
+ in_states=model.states(brainstate.ShortTermState),
100
99
  out_states=model.c
101
100
  )(
102
- bst.random.split_key(5)
101
+ brainstate.random.split_key(5)
103
102
  )
104
- print(bst.random.DEFAULT)
103
+ print(brainstate.random.DEFAULT)
105
104
 
106
105
  def test_vmap_with_random_v3(self):
107
- class Model(bst.nn.Module):
106
+ class Model(brainstate.nn.Module):
108
107
  def __init__(self):
109
108
  super().__init__()
110
109
 
111
- self.a = bst.ShortTermState(bst.random.randn(5))
112
- self.b = bst.ShortTermState(bst.random.randn(5))
113
- self.c = bst.State(bst.random.randn(1))
110
+ self.a = brainstate.ShortTermState(brainstate.random.randn(5))
111
+ self.b = brainstate.ShortTermState(brainstate.random.randn(5))
112
+ self.c = brainstate.State(brainstate.random.randn(1))
114
113
 
115
114
  def __call__(self):
116
115
  self.c.value = self.a.value * self.b.value
117
- return self.c.value + bst.random.randn(1)
116
+ return self.c.value + brainstate.random.randn(1)
118
117
 
119
118
  model = Model()
120
- r2 = bst.augment.vmap(
119
+ r2 = brainstate.augment.vmap(
121
120
  model,
122
- in_states=model.states(bst.ShortTermState),
121
+ in_states=model.states(brainstate.ShortTermState),
123
122
  out_states=model.c
124
123
  )()
125
- print(bst.random.DEFAULT)
124
+ print(brainstate.random.DEFAULT)
126
125
 
127
126
  def test_vmap_with_random_2(self):
128
- class Model(bst.nn.Module):
127
+ class Model(brainstate.nn.Module):
129
128
  def __init__(self):
130
129
  super().__init__()
131
130
 
132
- self.a = bst.ShortTermState(bst.random.randn(5))
133
- self.b = bst.ShortTermState(bst.random.randn(5))
134
- self.c = bst.State(bst.random.randn(1))
135
- self.rng = bst.random.RandomState(1)
131
+ self.a = brainstate.ShortTermState(brainstate.random.randn(5))
132
+ self.b = brainstate.ShortTermState(brainstate.random.randn(5))
133
+ self.c = brainstate.State(brainstate.random.randn(1))
134
+ self.rng = brainstate.random.RandomState(1)
136
135
 
137
136
  def __call__(self, key):
138
137
  self.rng.set_key(key)
139
138
  self.c.value = self.a.value * self.b.value
140
- return self.c.value + bst.random.randn(1)
139
+ return self.c.value + brainstate.random.randn(1)
141
140
 
142
141
  model = Model()
143
- r2 = bst.augment.vmap(
142
+ r2 = brainstate.augment.vmap(
144
143
  model,
145
- in_states=model.states(bst.ShortTermState),
144
+ in_states=model.states(brainstate.ShortTermState),
146
145
  out_states=model.c
147
146
  )(
148
- bst.random.split_key(5)
147
+ brainstate.random.split_key(5)
149
148
  )
150
149
 
151
150
  def test_vmap_input(self):
152
- model = bst.nn.Linear(2, 3)
151
+ model = brainstate.nn.Linear(2, 3)
153
152
  print(id(model), id(model.weight))
154
153
  model_id = id(model)
155
154
  weight_id = id(model.weight)
156
155
 
157
156
  x = jnp.ones((5, 2))
158
157
 
159
- @bst.augment.vmap
158
+ @brainstate.augment.vmap
160
159
  def forward(x):
161
160
  self.assertTrue(id(model) == model_id)
162
161
  self.assertTrue(id(model.weight) == weight_id)
@@ -169,39 +168,39 @@ class TestVmap(unittest.TestCase):
169
168
  print(model.weight.value)
170
169
 
171
170
  def test_vmap_states_and_input_1(self):
172
- gru = bst.nn.GRUCell(2, 3)
171
+ gru = brainstate.nn.GRUCell(2, 3)
173
172
  gru.init_state(5)
174
173
 
175
- @bst.augment.vmap(in_states=gru.states(bst.HiddenState))
174
+ @brainstate.augment.vmap(in_states=gru.states(brainstate.HiddenState))
176
175
  def forward(x):
177
176
  return gru(x)
178
177
 
179
- xs = bst.random.randn(5, 2)
178
+ xs = brainstate.random.randn(5, 2)
180
179
  y = forward(xs)
181
180
  self.assertTrue(y.shape == (5, 3))
182
181
 
183
182
  def test_vmap_jit(self):
184
- class Foo(bst.nn.Module):
183
+ class Foo(brainstate.nn.Module):
185
184
  def __init__(self):
186
185
  super().__init__()
187
- self.a = bst.ParamState(jnp.arange(4))
188
- self.b = bst.ShortTermState(jnp.arange(4))
186
+ self.a = brainstate.ParamState(jnp.arange(4))
187
+ self.b = brainstate.ShortTermState(jnp.arange(4))
189
188
 
190
189
  def __call__(self):
191
190
  self.b.value = self.a.value * self.b.value
192
191
 
193
192
  foo = Foo()
194
193
 
195
- @bst.augment.vmap(in_states=foo.states())
194
+ @brainstate.augment.vmap(in_states=foo.states())
196
195
  def mul():
197
196
  foo()
198
197
 
199
- @bst.compile.jit
198
+ @brainstate.compile.jit
200
199
  def mul_jit(inp):
201
200
  mul()
202
201
  foo.a.value += inp
203
202
 
204
- with bst.StateTraceStack() as trace:
203
+ with brainstate.StateTraceStack() as trace:
205
204
  mul_jit(1.)
206
205
 
207
206
  print(foo.a.value)
@@ -219,27 +218,27 @@ class TestVmap(unittest.TestCase):
219
218
  print(trace.get_read_states())
220
219
 
221
220
  def test_vmap_jit_2(self):
222
- class Foo(bst.nn.Module):
221
+ class Foo(brainstate.nn.Module):
223
222
  def __init__(self):
224
223
  super().__init__()
225
- self.a = bst.ParamState(jnp.arange(4))
226
- self.b = bst.ShortTermState(jnp.arange(4))
224
+ self.a = brainstate.ParamState(jnp.arange(4))
225
+ self.b = brainstate.ShortTermState(jnp.arange(4))
227
226
 
228
227
  def __call__(self):
229
228
  self.b.value = self.a.value * self.b.value
230
229
 
231
230
  foo = Foo()
232
231
 
233
- @bst.augment.vmap(in_states=foo.states())
232
+ @brainstate.augment.vmap(in_states=foo.states())
234
233
  def mul():
235
234
  foo()
236
235
 
237
- @bst.compile.jit
236
+ @brainstate.compile.jit
238
237
  def mul_jit(inp):
239
238
  mul()
240
239
  foo.b.value += inp
241
240
 
242
- with bst.StateTraceStack() as trace:
241
+ with brainstate.StateTraceStack() as trace:
243
242
  mul_jit(1.)
244
243
 
245
244
  print(foo.a.value)
@@ -258,9 +257,9 @@ class TestVmap(unittest.TestCase):
258
257
 
259
258
  def test_auto_rand_key_split(self):
260
259
  def f():
261
- return bst.random.rand(1)
260
+ return brainstate.random.rand(1)
262
261
 
263
- res = bst.augment.vmap(f, axis_size=10)()
262
+ res = brainstate.augment.vmap(f, axis_size=10)()
264
263
  self.assertTrue(jnp.all(~(res[0] == res[1:])))
265
264
 
266
265
  res2 = jax.vmap(f, axis_size=10)()
@@ -278,17 +277,17 @@ class TestVmap(unittest.TestCase):
278
277
  self.assertTrue(jnp.allclose(r, r2))
279
278
 
280
279
  def test_vmap_init(self):
281
- class Foo(bst.nn.Module):
280
+ class Foo(brainstate.nn.Module):
282
281
  def __init__(self):
283
282
  super().__init__()
284
- self.a = bst.ParamState(jnp.arange(4))
285
- self.b = bst.ShortTermState(jnp.arange(4))
283
+ self.a = brainstate.ParamState(jnp.arange(4))
284
+ self.b = brainstate.ShortTermState(jnp.arange(4))
286
285
 
287
286
  def init_state_v1(self, *args, **kwargs):
288
- self.c = bst.State(jnp.arange(4))
287
+ self.c = brainstate.State(jnp.arange(4))
289
288
 
290
289
  def init_state_v2(self):
291
- self.d = bst.State(self.c.value * 2.)
290
+ self.d = brainstate.State(self.c.value * 2.)
292
291
 
293
292
  foo = Foo()
294
293
 
@@ -318,11 +317,11 @@ class TestVmap(unittest.TestCase):
318
317
  class TestMap(unittest.TestCase):
319
318
  def test_map(self):
320
319
  for dim in [(10,), (10, 10), (10, 10, 10)]:
321
- x = bst.random.rand(*dim)
322
- r1 = bst.augment.map(lambda a: a + 1, x, batch_size=None)
323
- r2 = bst.augment.map(lambda a: a + 1, x, batch_size=2)
324
- r3 = bst.augment.map(lambda a: a + 1, x, batch_size=4)
325
- r4 = bst.augment.map(lambda a: a + 1, x, batch_size=5)
320
+ x = brainstate.random.rand(*dim)
321
+ r1 = brainstate.augment.map(lambda a: a + 1, x, batch_size=None)
322
+ r2 = brainstate.augment.map(lambda a: a + 1, x, batch_size=2)
323
+ r3 = brainstate.augment.map(lambda a: a + 1, x, batch_size=4)
324
+ r4 = brainstate.augment.map(lambda a: a + 1, x, batch_size=5)
326
325
  true_r = x + 1
327
326
 
328
327
  self.assertTrue(jnp.allclose(r1, true_r))
@@ -406,7 +405,7 @@ class TestVMAPNewStatesEdgeCases(unittest.TestCase):
406
405
  foo = brainstate.nn.LIF(3)
407
406
  # Testing that axis_size of 0 raises an error.
408
407
  with self.assertRaises(ValueError):
409
- @bst.augment.vmap_new_states(state_tag='new1', axis_size=0)
408
+ @brainstate.augment.vmap_new_states(state_tag='new1', axis_size=0)
410
409
  def faulty_init():
411
410
  foo.init_state()
412
411
 
@@ -417,7 +416,7 @@ class TestVMAPNewStatesEdgeCases(unittest.TestCase):
417
416
  foo = brainstate.nn.LIF(3)
418
417
  # Testing that a negative axis_size raises an error.
419
418
  with self.assertRaises(ValueError):
420
- @bst.augment.vmap_new_states(state_tag='new1', axis_size=-3)
419
+ @brainstate.augment.vmap_new_states(state_tag='new1', axis_size=-3)
421
420
  def faulty_init():
422
421
  foo.init_state()
423
422
 
@@ -428,9 +427,9 @@ class TestVMAPNewStatesEdgeCases(unittest.TestCase):
428
427
 
429
428
  # Simulate an incompatible shapes scenario:
430
429
  # We intentionally assign a state with a different shape than expected.
431
- @bst.augment.vmap_new_states(state_tag='new1', axis_size=5)
430
+ @brainstate.augment.vmap_new_states(state_tag='new1', axis_size=5)
432
431
  def faulty_init():
433
432
  # Modify state to produce an incompatible shape
434
- foo.c = bst.State(jnp.arange(3)) # Original expected shape is (4,)
433
+ foo.c = brainstate.State(jnp.arange(3)) # Original expected shape is (4,)
435
434
 
436
435
  faulty_init()
@@ -13,34 +13,32 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- from __future__ import annotations
17
-
18
16
  import jax
19
17
  import jax.numpy as jnp
20
18
  from absl.testing import absltest
21
19
 
22
- import brainstate as bst
20
+ import brainstate
23
21
 
24
22
 
25
23
  class TestRemat(absltest.TestCase):
26
24
  def test_basic_remat(self):
27
- module = bst.compile.remat(bst.nn.Linear(2, 3))
25
+ module = brainstate.compile.remat(brainstate.nn.Linear(2, 3))
28
26
  y = module(jnp.ones((1, 2)))
29
27
  assert y.shape == (1, 3)
30
28
 
31
29
  def test_remat_with_scan(self):
32
- class ScanLinear(bst.nn.Module):
30
+ class ScanLinear(brainstate.nn.Module):
33
31
  def __init__(self):
34
32
  super().__init__()
35
- self.linear = bst.nn.Linear(3, 3)
33
+ self.linear = brainstate.nn.Linear(3, 3)
36
34
 
37
35
  def __call__(self, x: jax.Array):
38
- @bst.compile.remat
36
+ @brainstate.compile.remat
39
37
  def fun(x: jax.Array, _):
40
38
  x = self.linear(x)
41
39
  return x, None
42
40
 
43
- return bst.compile.scan(fun, x, None, length=10)[0]
41
+ return brainstate.compile.scan(fun, x, None, length=10)[0]
44
42
 
45
43
  m = ScanLinear()
46
44
 
@@ -12,27 +12,26 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
- from __future__ import annotations
16
15
 
17
16
  import unittest
18
17
 
19
18
  import jax
20
19
  import jax.numpy as jnp
21
20
 
22
- import brainstate as bst
21
+ import brainstate
23
22
 
24
23
 
25
24
  class TestCond(unittest.TestCase):
26
25
  def test1(self):
27
- bst.random.seed(1)
28
- bst.compile.cond(True, lambda: bst.random.random(10), lambda: bst.random.random(10))
29
- bst.compile.cond(False, lambda: bst.random.random(10), lambda: bst.random.random(10))
26
+ brainstate.random.seed(1)
27
+ brainstate.compile.cond(True, lambda: brainstate.random.random(10), lambda: brainstate.random.random(10))
28
+ brainstate.compile.cond(False, lambda: brainstate.random.random(10), lambda: brainstate.random.random(10))
30
29
 
31
30
  def test2(self):
32
- st1 = bst.State(bst.random.rand(10))
33
- st2 = bst.State(bst.random.rand(2))
34
- st3 = bst.State(bst.random.rand(5))
35
- st4 = bst.State(bst.random.rand(2, 10))
31
+ st1 = brainstate.State(brainstate.random.rand(10))
32
+ st2 = brainstate.State(brainstate.random.rand(2))
33
+ st3 = brainstate.State(brainstate.random.rand(5))
34
+ st4 = brainstate.State(brainstate.random.rand(2, 10))
36
35
 
37
36
  def true_fun(x):
38
37
  st1.value = st2.value @ st4.value + x
@@ -40,7 +39,7 @@ class TestCond(unittest.TestCase):
40
39
  def false_fun(x):
41
40
  st3.value = (st3.value + 1.) * x
42
41
 
43
- bst.compile.cond(True, true_fun, false_fun, 2.)
42
+ brainstate.compile.cond(True, true_fun, false_fun, 2.)
44
43
  assert not isinstance(st1.value, jax.core.Tracer)
45
44
  assert not isinstance(st2.value, jax.core.Tracer)
46
45
  assert not isinstance(st3.value, jax.core.Tracer)
@@ -66,7 +65,7 @@ class TestSwitch(unittest.TestCase):
66
65
  return branches[2](x)
67
66
 
68
67
  def cfun(x):
69
- return bst.compile.switch(x, branches, x)
68
+ return brainstate.compile.switch(x, branches, x)
70
69
 
71
70
  self.assertEqual(fun(-1), cfun(-1))
72
71
  self.assertEqual(fun(0), cfun(0))
@@ -90,7 +89,7 @@ class TestSwitch(unittest.TestCase):
90
89
  return branches[i](x, x)
91
90
 
92
91
  def cfun(x):
93
- return bst.compile.switch(x, branches, x, x)
92
+ return brainstate.compile.switch(x, branches, x, x)
94
93
 
95
94
  self.assertEqual(fun(-1), cfun(-1))
96
95
  self.assertEqual(fun(0), cfun(0))
@@ -123,13 +122,13 @@ class TestSwitch(unittest.TestCase):
123
122
  branches3 = branches2 + [lambda x: jnp.sin(x) + jnp.cos(x)] # requires one more residual slot
124
123
 
125
124
  def fun1(x, i):
126
- return bst.compile.switch(i + 1, branches1, x)
125
+ return brainstate.compile.switch(i + 1, branches1, x)
127
126
 
128
127
  def fun2(x, i):
129
- return bst.compile.switch(i + 1, branches2, x)
128
+ return brainstate.compile.switch(i + 1, branches2, x)
130
129
 
131
130
  def fun3(x, i):
132
- return bst.compile.switch(i + 1, branches3, x)
131
+ return brainstate.compile.switch(i + 1, branches3, x)
133
132
 
134
133
  fwd1, bwd1 = get_conds(fun1)
135
134
  fwd2, bwd2 = get_conds(fun2)
@@ -149,7 +148,7 @@ class TestSwitch(unittest.TestCase):
149
148
 
150
149
  def testOneBranchSwitch(self):
151
150
  branch = lambda x: -x
152
- f = lambda i, x: bst.compile.switch(i, [branch], x)
151
+ f = lambda i, x: brainstate.compile.switch(i, [branch], x)
153
152
  x = 7.
154
153
  self.assertEqual(f(-1, x), branch(x))
155
154
  self.assertEqual(f(0, x), branch(x))
@@ -167,12 +166,12 @@ class TestSwitch(unittest.TestCase):
167
166
  class TestIfElse(unittest.TestCase):
168
167
  def test1(self):
169
168
  def f(a):
170
- return bst.compile.ifelse(conditions=[a < 0,
171
- a >= 0 and a < 2,
172
- a >= 2 and a < 5,
173
- a >= 5 and a < 10,
174
- a >= 10],
175
- branches=[lambda: 1,
169
+ return brainstate.compile.ifelse(conditions=[a < 0,
170
+ a >= 0 and a < 2,
171
+ a >= 2 and a < 5,
172
+ a >= 5 and a < 10,
173
+ a >= 10],
174
+ branches=[lambda: 1,
176
175
  lambda: 2,
177
176
  lambda: 3,
178
177
  lambda: 4,
@@ -184,38 +183,38 @@ class TestIfElse(unittest.TestCase):
184
183
 
185
184
  def test_vmap(self):
186
185
  def f(operands):
187
- f = lambda a: bst.compile.ifelse([a > 10,
188
- jnp.logical_and(a <= 10, a > 5),
189
- jnp.logical_and(a <= 5, a > 2),
190
- jnp.logical_and(a <= 2, a > 0),
191
- a <= 0],
192
- [lambda _: 1,
186
+ f = lambda a: brainstate.compile.ifelse([a > 10,
187
+ jnp.logical_and(a <= 10, a > 5),
188
+ jnp.logical_and(a <= 5, a > 2),
189
+ jnp.logical_and(a <= 2, a > 0),
190
+ a <= 0],
191
+ [lambda _: 1,
193
192
  lambda _: 2,
194
193
  lambda _: 3,
195
194
  lambda _: 4,
196
195
  lambda _: 5, ],
197
- a)
196
+ a)
198
197
  return jax.vmap(f)(operands)
199
198
 
200
- r = f(bst.random.randint(-20, 20, 200))
199
+ r = f(brainstate.random.randint(-20, 20, 200))
201
200
  self.assertTrue(r.size == 200)
202
201
 
203
202
  def test_grad1(self):
204
203
  def F2(x):
205
- return bst.compile.ifelse((x >= 10, x < 10),
206
- [lambda x: x, lambda x: x ** 2, ],
207
- x)
204
+ return brainstate.compile.ifelse((x >= 10, x < 10),
205
+ [lambda x: x, lambda x: x ** 2, ],
206
+ x)
208
207
 
209
208
  self.assertTrue(jax.grad(F2)(9.0) == 18.)
210
209
  self.assertTrue(jax.grad(F2)(11.0) == 1.)
211
210
 
212
211
  def test_grad2(self):
213
212
  def F3(x):
214
- return bst.compile.ifelse((x >= 10, jnp.logical_and(x >= 0, x < 10), x < 0),
215
- [lambda x: x,
213
+ return brainstate.compile.ifelse((x >= 10, jnp.logical_and(x >= 0, x < 10), x < 0),
214
+ [lambda x: x,
216
215
  lambda x: x ** 2,
217
216
  lambda x: x ** 4, ],
218
- x)
217
+ x)
219
218
 
220
219
  self.assertTrue(jax.grad(F3)(9.0) == 18.)
221
220
  self.assertTrue(jax.grad(F3)(11.0) == 1.)
@@ -13,43 +13,40 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- from __future__ import annotations
17
-
18
16
  import unittest
19
17
 
20
18
  import jax
21
19
  import jax.numpy as jnp
22
- import jaxlib.xla_extension
23
20
 
24
- import brainstate as bst
21
+ import brainstate
25
22
 
26
23
 
27
24
  class TestJitError(unittest.TestCase):
28
25
  def test1(self):
29
- with self.assertRaises(jaxlib.xla_extension.XlaRuntimeError):
30
- bst.compile.jit_error_if(True, 'error')
26
+ with self.assertRaises(Exception):
27
+ brainstate.compile.jit_error_if(True, 'error')
31
28
 
32
29
  def err_f(x):
33
30
  raise ValueError(f'error: {x}')
34
31
 
35
- bst.compile.jit_error_if(False, err_f, 1.)
36
- with self.assertRaises(jaxlib.xla_extension.XlaRuntimeError):
37
- bst.compile.jit_error_if(True, err_f, 1.)
32
+ brainstate.compile.jit_error_if(False, err_f, 1.)
33
+ with self.assertRaises(Exception):
34
+ brainstate.compile.jit_error_if(True, err_f, 1.)
38
35
 
39
36
  def test_vmap(self):
40
37
  def f(x):
41
- bst.compile.jit_error_if(x, 'error: {x}', x=x)
38
+ brainstate.compile.jit_error_if(x, 'error: {x}', x=x)
42
39
 
43
40
  jax.vmap(f)(jnp.array([False, False, False]))
44
- with self.assertRaises(jaxlib.xla_extension.XlaRuntimeError):
41
+ with self.assertRaises(Exception):
45
42
  jax.vmap(f)(jnp.array([True, False, False]))
46
43
 
47
44
  def test_vmap_vmap(self):
48
45
  def f(x):
49
- bst.compile.jit_error_if(x, 'error: {x}', x=x)
46
+ brainstate.compile.jit_error_if(x, 'error: {x}', x=x)
50
47
 
51
48
  jax.vmap(jax.vmap(f))(jnp.array([[False, False, False],
52
49
  [False, False, False]]))
53
- with self.assertRaises(jaxlib.xla_extension.XlaRuntimeError):
50
+ with self.assertRaises(Exception):
54
51
  jax.vmap(jax.vmap(f))(jnp.array([[False, False, False],
55
52
  [True, False, False]]))