brainstate 0.1.0.post20250503__py2.py3-none-any.whl → 0.1.2__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (111) hide show
  1. brainstate/__init__.py +1 -1
  2. brainstate/_compatible_import.py +10 -3
  3. brainstate/_state.py +178 -178
  4. brainstate/_utils.py +0 -1
  5. brainstate/augment/_autograd.py +0 -2
  6. brainstate/augment/_autograd_test.py +132 -133
  7. brainstate/augment/_eval_shape.py +0 -2
  8. brainstate/augment/_eval_shape_test.py +7 -9
  9. brainstate/augment/_mapping.py +2 -3
  10. brainstate/augment/_mapping_test.py +75 -76
  11. brainstate/augment/_random.py +0 -2
  12. brainstate/compile/_ad_checkpoint.py +0 -2
  13. brainstate/compile/_ad_checkpoint_test.py +6 -8
  14. brainstate/compile/_conditions.py +0 -2
  15. brainstate/compile/_conditions_test.py +35 -36
  16. brainstate/compile/_error_if.py +0 -2
  17. brainstate/compile/_error_if_test.py +10 -13
  18. brainstate/compile/_jit.py +9 -8
  19. brainstate/compile/_loop_collect_return.py +0 -2
  20. brainstate/compile/_loop_collect_return_test.py +7 -9
  21. brainstate/compile/_loop_no_collection.py +0 -2
  22. brainstate/compile/_loop_no_collection_test.py +7 -8
  23. brainstate/compile/_make_jaxpr.py +30 -17
  24. brainstate/compile/_make_jaxpr_test.py +20 -20
  25. brainstate/compile/_progress_bar.py +0 -1
  26. brainstate/compile/_unvmap.py +0 -1
  27. brainstate/compile/_util.py +0 -2
  28. brainstate/environ.py +0 -2
  29. brainstate/functional/_activations.py +0 -2
  30. brainstate/functional/_activations_test.py +61 -61
  31. brainstate/functional/_normalization.py +0 -2
  32. brainstate/functional/_others.py +0 -2
  33. brainstate/functional/_spikes.py +0 -1
  34. brainstate/graph/_graph_node.py +1 -3
  35. brainstate/graph/_graph_node_test.py +16 -18
  36. brainstate/graph/_graph_operation.py +4 -2
  37. brainstate/graph/_graph_operation_test.py +154 -156
  38. brainstate/init/_base.py +0 -2
  39. brainstate/init/_generic.py +0 -1
  40. brainstate/init/_random_inits.py +0 -1
  41. brainstate/init/_random_inits_test.py +20 -21
  42. brainstate/init/_regular_inits.py +0 -2
  43. brainstate/init/_regular_inits_test.py +4 -5
  44. brainstate/mixin.py +0 -2
  45. brainstate/nn/_collective_ops.py +0 -3
  46. brainstate/nn/_collective_ops_test.py +8 -8
  47. brainstate/nn/_common.py +0 -2
  48. brainstate/nn/_dyn_impl/_dynamics_neuron.py +0 -2
  49. brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +18 -19
  50. brainstate/nn/_dyn_impl/_dynamics_synapse.py +0 -1
  51. brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +9 -10
  52. brainstate/nn/_dyn_impl/_inputs.py +0 -1
  53. brainstate/nn/_dyn_impl/_rate_rnns.py +0 -1
  54. brainstate/nn/_dyn_impl/_rate_rnns_test.py +6 -7
  55. brainstate/nn/_dyn_impl/_readout.py +0 -1
  56. brainstate/nn/_dyn_impl/_readout_test.py +9 -10
  57. brainstate/nn/_dynamics/_dynamics_base.py +0 -1
  58. brainstate/nn/_dynamics/_dynamics_base_test.py +14 -15
  59. brainstate/nn/_dynamics/_projection_base.py +0 -1
  60. brainstate/nn/_dynamics/_state_delay.py +0 -2
  61. brainstate/nn/_dynamics/_synouts.py +0 -2
  62. brainstate/nn/_dynamics/_synouts_test.py +4 -5
  63. brainstate/nn/_elementwise/_dropout.py +0 -2
  64. brainstate/nn/_elementwise/_dropout_test.py +9 -9
  65. brainstate/nn/_elementwise/_elementwise.py +0 -2
  66. brainstate/nn/_elementwise/_elementwise_test.py +57 -59
  67. brainstate/nn/_event/_fixedprob_mv.py +0 -1
  68. brainstate/nn/_event/_fixedprob_mv_test.py +0 -1
  69. brainstate/nn/_event/_linear_mv.py +0 -2
  70. brainstate/nn/_event/_linear_mv_test.py +0 -1
  71. brainstate/nn/_exp_euler.py +0 -2
  72. brainstate/nn/_exp_euler_test.py +5 -6
  73. brainstate/nn/_interaction/_conv.py +0 -2
  74. brainstate/nn/_interaction/_conv_test.py +31 -33
  75. brainstate/nn/_interaction/_embedding.py +0 -1
  76. brainstate/nn/_interaction/_linear.py +0 -2
  77. brainstate/nn/_interaction/_linear_test.py +15 -17
  78. brainstate/nn/_interaction/_normalizations.py +0 -2
  79. brainstate/nn/_interaction/_normalizations_test.py +10 -12
  80. brainstate/nn/_interaction/_poolings.py +0 -2
  81. brainstate/nn/_interaction/_poolings_test.py +19 -21
  82. brainstate/nn/_module.py +0 -1
  83. brainstate/nn/_module_test.py +34 -37
  84. brainstate/nn/metrics.py +0 -2
  85. brainstate/optim/_base.py +0 -2
  86. brainstate/optim/_lr_scheduler.py +0 -1
  87. brainstate/optim/_lr_scheduler_test.py +3 -3
  88. brainstate/optim/_optax_optimizer.py +0 -2
  89. brainstate/optim/_optax_optimizer_test.py +8 -9
  90. brainstate/optim/_sgd_optimizer.py +0 -1
  91. brainstate/random/_rand_funs.py +0 -1
  92. brainstate/random/_rand_funs_test.py +183 -184
  93. brainstate/random/_rand_seed.py +0 -1
  94. brainstate/random/_rand_seed_test.py +10 -12
  95. brainstate/random/_rand_state.py +0 -1
  96. brainstate/surrogate.py +0 -1
  97. brainstate/typing.py +0 -2
  98. brainstate/util/_caller.py +4 -6
  99. brainstate/util/_others.py +0 -2
  100. brainstate/util/_pretty_pytree.py +201 -150
  101. brainstate/util/_pretty_repr.py +0 -2
  102. brainstate/util/_pretty_table.py +57 -3
  103. brainstate/util/_scaling.py +0 -2
  104. brainstate/util/_struct.py +0 -2
  105. brainstate/util/filter.py +0 -2
  106. {brainstate-0.1.0.post20250503.dist-info → brainstate-0.1.2.dist-info}/METADATA +11 -5
  107. brainstate-0.1.2.dist-info/RECORD +133 -0
  108. brainstate-0.1.0.post20250503.dist-info/RECORD +0 -133
  109. {brainstate-0.1.0.post20250503.dist-info → brainstate-0.1.2.dist-info}/LICENSE +0 -0
  110. {brainstate-0.1.0.post20250503.dist-info → brainstate-0.1.2.dist-info}/WHEEL +0 -0
  111. {brainstate-0.1.0.post20250503.dist-info → brainstate-0.1.2.dist-info}/top_level.txt +0 -0
@@ -13,7 +13,6 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- from __future__ import annotations
17
16
 
18
17
  import platform
19
18
  import unittest
@@ -23,110 +22,110 @@ import jax.random as jr
23
22
  import numpy as np
24
23
  import pytest
25
24
 
26
- import brainstate as bst
25
+ import brainstate
27
26
 
28
27
 
29
28
  class TestRandom(unittest.TestCase):
30
29
 
31
30
  def test_rand(self):
32
- bst.random.seed()
33
- a = bst.random.rand(3, 2)
31
+ brainstate.random.seed()
32
+ a = brainstate.random.rand(3, 2)
34
33
  self.assertTupleEqual(a.shape, (3, 2))
35
34
  self.assertTrue((a >= 0).all() and (a < 1).all())
36
35
 
37
36
  key = jr.PRNGKey(123)
38
37
  jres = jr.uniform(key, shape=(10, 100))
39
- self.assertTrue(jnp.allclose(jres, bst.random.rand(10, 100, key=key)))
40
- self.assertTrue(jnp.allclose(jres, bst.random.rand(10, 100, key=123)))
38
+ self.assertTrue(jnp.allclose(jres, brainstate.random.rand(10, 100, key=key)))
39
+ self.assertTrue(jnp.allclose(jres, brainstate.random.rand(10, 100, key=123)))
41
40
 
42
41
  def test_randint1(self):
43
- bst.random.seed()
44
- a = bst.random.randint(5)
42
+ brainstate.random.seed()
43
+ a = brainstate.random.randint(5)
45
44
  self.assertTupleEqual(a.shape, ())
46
45
  self.assertTrue(0 <= a < 5)
47
46
 
48
47
  def test_randint2(self):
49
- bst.random.seed()
50
- a = bst.random.randint(2, 6, size=(4, 3))
48
+ brainstate.random.seed()
49
+ a = brainstate.random.randint(2, 6, size=(4, 3))
51
50
  self.assertTupleEqual(a.shape, (4, 3))
52
51
  self.assertTrue((a >= 2).all() and (a < 6).all())
53
52
 
54
53
  def test_randint3(self):
55
- bst.random.seed()
56
- a = bst.random.randint([1, 2, 3], [10, 7, 8])
54
+ brainstate.random.seed()
55
+ a = brainstate.random.randint([1, 2, 3], [10, 7, 8])
57
56
  self.assertTupleEqual(a.shape, (3,))
58
57
  self.assertTrue((a - jnp.array([1, 2, 3]) >= 0).all()
59
58
  and (-a + jnp.array([10, 7, 8]) > 0).all())
60
59
 
61
60
  def test_randint4(self):
62
- bst.random.seed()
63
- a = bst.random.randint([1, 2, 3], [10, 7, 8], size=(2, 3))
61
+ brainstate.random.seed()
62
+ a = brainstate.random.randint([1, 2, 3], [10, 7, 8], size=(2, 3))
64
63
  self.assertTupleEqual(a.shape, (2, 3))
65
64
 
66
65
  def test_randn(self):
67
- bst.random.seed()
68
- a = bst.random.randn(3, 2)
66
+ brainstate.random.seed()
67
+ a = brainstate.random.randn(3, 2)
69
68
  self.assertTupleEqual(a.shape, (3, 2))
70
69
 
71
70
  def test_random1(self):
72
- bst.random.seed()
73
- a = bst.random.random()
71
+ brainstate.random.seed()
72
+ a = brainstate.random.random()
74
73
  self.assertTrue(0. <= a < 1)
75
74
 
76
75
  def test_random2(self):
77
- bst.random.seed()
78
- a = bst.random.random(size=(3, 2))
76
+ brainstate.random.seed()
77
+ a = brainstate.random.random(size=(3, 2))
79
78
  self.assertTupleEqual(a.shape, (3, 2))
80
79
  self.assertTrue((a >= 0).all() and (a < 1).all())
81
80
 
82
81
  def test_random_sample(self):
83
- bst.random.seed()
84
- a = bst.random.random_sample(size=(3, 2))
82
+ brainstate.random.seed()
83
+ a = brainstate.random.random_sample(size=(3, 2))
85
84
  self.assertTupleEqual(a.shape, (3, 2))
86
85
  self.assertTrue((a >= 0).all() and (a < 1).all())
87
86
 
88
87
  def test_choice1(self):
89
- bst.random.seed()
90
- a = bst.random.choice(5)
88
+ brainstate.random.seed()
89
+ a = brainstate.random.choice(5)
91
90
  self.assertTupleEqual(jnp.shape(a), ())
92
91
  self.assertTrue(0 <= a < 5)
93
92
 
94
93
  def test_choice2(self):
95
- bst.random.seed()
96
- a = bst.random.choice(5, 3, p=[0.1, 0.4, 0.2, 0., 0.3])
94
+ brainstate.random.seed()
95
+ a = brainstate.random.choice(5, 3, p=[0.1, 0.4, 0.2, 0., 0.3])
97
96
  self.assertTupleEqual(a.shape, (3,))
98
97
  self.assertTrue((a >= 0).all() and (a < 5).all())
99
98
 
100
99
  def test_choice3(self):
101
- bst.random.seed()
102
- a = bst.random.choice(jnp.arange(2, 20), size=(4, 3), replace=False)
100
+ brainstate.random.seed()
101
+ a = brainstate.random.choice(jnp.arange(2, 20), size=(4, 3), replace=False)
103
102
  self.assertTupleEqual(a.shape, (4, 3))
104
103
  self.assertTrue((a >= 2).all() and (a < 20).all())
105
104
  self.assertEqual(len(jnp.unique(a)), 12)
106
105
 
107
106
  def test_permutation1(self):
108
- bst.random.seed()
109
- a = bst.random.permutation(10)
107
+ brainstate.random.seed()
108
+ a = brainstate.random.permutation(10)
110
109
  self.assertTupleEqual(a.shape, (10,))
111
110
  self.assertEqual(len(jnp.unique(a)), 10)
112
111
 
113
112
  def test_permutation2(self):
114
- bst.random.seed()
115
- a = bst.random.permutation(jnp.arange(10))
113
+ brainstate.random.seed()
114
+ a = brainstate.random.permutation(jnp.arange(10))
116
115
  self.assertTupleEqual(a.shape, (10,))
117
116
  self.assertEqual(len(jnp.unique(a)), 10)
118
117
 
119
118
  def test_shuffle1(self):
120
- bst.random.seed()
119
+ brainstate.random.seed()
121
120
  a = jnp.arange(10)
122
- bst.random.shuffle(a)
121
+ brainstate.random.shuffle(a)
123
122
  self.assertTupleEqual(a.shape, (10,))
124
123
  self.assertEqual(len(jnp.unique(a)), 10)
125
124
 
126
125
  def test_shuffle2(self):
127
- bst.random.seed()
126
+ brainstate.random.seed()
128
127
  a = jnp.arange(12).reshape(4, 3)
129
- bst.random.shuffle(a, axis=1)
128
+ brainstate.random.shuffle(a, axis=1)
130
129
  self.assertTupleEqual(a.shape, (4, 3))
131
130
  self.assertEqual(len(jnp.unique(a)), 12)
132
131
 
@@ -135,173 +134,173 @@ class TestRandom(unittest.TestCase):
135
134
  self.assertEqual(uni, jnp.asarray([3]))
136
135
 
137
136
  def test_beta1(self):
138
- bst.random.seed()
139
- a = bst.random.beta(2, 2)
137
+ brainstate.random.seed()
138
+ a = brainstate.random.beta(2, 2)
140
139
  self.assertTupleEqual(a.shape, ())
141
140
 
142
141
  def test_beta2(self):
143
- bst.random.seed()
144
- a = bst.random.beta([2, 2, 3], 2, size=(3,))
142
+ brainstate.random.seed()
143
+ a = brainstate.random.beta([2, 2, 3], 2, size=(3,))
145
144
  self.assertTupleEqual(a.shape, (3,))
146
145
 
147
146
  def test_exponential1(self):
148
- bst.random.seed()
149
- a = bst.random.exponential(10., size=[3, 2])
147
+ brainstate.random.seed()
148
+ a = brainstate.random.exponential(10., size=[3, 2])
150
149
  self.assertTupleEqual(a.shape, (3, 2))
151
150
 
152
151
  def test_exponential2(self):
153
- bst.random.seed()
154
- a = bst.random.exponential([1., 2., 5.])
152
+ brainstate.random.seed()
153
+ a = brainstate.random.exponential([1., 2., 5.])
155
154
  self.assertTupleEqual(a.shape, (3,))
156
155
 
157
156
  def test_gamma(self):
158
- bst.random.seed()
159
- a = bst.random.gamma(2, 10., size=[3, 2])
157
+ brainstate.random.seed()
158
+ a = brainstate.random.gamma(2, 10., size=[3, 2])
160
159
  self.assertTupleEqual(a.shape, (3, 2))
161
160
 
162
161
  def test_gumbel(self):
163
- bst.random.seed()
164
- a = bst.random.gumbel(0., 2., size=[3, 2])
162
+ brainstate.random.seed()
163
+ a = brainstate.random.gumbel(0., 2., size=[3, 2])
165
164
  self.assertTupleEqual(a.shape, (3, 2))
166
165
 
167
166
  def test_laplace(self):
168
- bst.random.seed()
169
- a = bst.random.laplace(0., 2., size=[3, 2])
167
+ brainstate.random.seed()
168
+ a = brainstate.random.laplace(0., 2., size=[3, 2])
170
169
  self.assertTupleEqual(a.shape, (3, 2))
171
170
 
172
171
  def test_logistic(self):
173
- bst.random.seed()
174
- a = bst.random.logistic(0., 2., size=[3, 2])
172
+ brainstate.random.seed()
173
+ a = brainstate.random.logistic(0., 2., size=[3, 2])
175
174
  self.assertTupleEqual(a.shape, (3, 2))
176
175
 
177
176
  def test_normal1(self):
178
- bst.random.seed()
179
- a = bst.random.normal()
177
+ brainstate.random.seed()
178
+ a = brainstate.random.normal()
180
179
  self.assertTupleEqual(a.shape, ())
181
180
 
182
181
  def test_normal2(self):
183
- bst.random.seed()
184
- a = bst.random.normal(loc=[0., 2., 4.], scale=[1., 2., 3.])
182
+ brainstate.random.seed()
183
+ a = brainstate.random.normal(loc=[0., 2., 4.], scale=[1., 2., 3.])
185
184
  self.assertTupleEqual(a.shape, (3,))
186
185
 
187
186
  def test_normal3(self):
188
- bst.random.seed()
189
- a = bst.random.normal(loc=[0., 2., 4.], scale=[[1., 2., 3.], [1., 1., 1.]])
187
+ brainstate.random.seed()
188
+ a = brainstate.random.normal(loc=[0., 2., 4.], scale=[[1., 2., 3.], [1., 1., 1.]])
190
189
  print(a)
191
190
  self.assertTupleEqual(a.shape, (2, 3))
192
191
 
193
192
  def test_pareto(self):
194
- bst.random.seed()
195
- a = bst.random.pareto([1, 2, 2])
193
+ brainstate.random.seed()
194
+ a = brainstate.random.pareto([1, 2, 2])
196
195
  self.assertTupleEqual(a.shape, (3,))
197
196
 
198
197
  def test_poisson(self):
199
- bst.random.seed()
200
- a = bst.random.poisson([1., 2., 2.], size=3)
198
+ brainstate.random.seed()
199
+ a = brainstate.random.poisson([1., 2., 2.], size=3)
201
200
  self.assertTupleEqual(a.shape, (3,))
202
201
 
203
202
  def test_standard_cauchy(self):
204
- bst.random.seed()
205
- a = bst.random.standard_cauchy(size=(3, 2))
203
+ brainstate.random.seed()
204
+ a = brainstate.random.standard_cauchy(size=(3, 2))
206
205
  self.assertTupleEqual(a.shape, (3, 2))
207
206
 
208
207
  def test_standard_exponential(self):
209
- bst.random.seed()
210
- a = bst.random.standard_exponential(size=(3, 2))
208
+ brainstate.random.seed()
209
+ a = brainstate.random.standard_exponential(size=(3, 2))
211
210
  self.assertTupleEqual(a.shape, (3, 2))
212
211
 
213
212
  def test_standard_gamma(self):
214
- bst.random.seed()
215
- a = bst.random.standard_gamma(shape=[1, 2, 4], size=3)
213
+ brainstate.random.seed()
214
+ a = brainstate.random.standard_gamma(shape=[1, 2, 4], size=3)
216
215
  self.assertTupleEqual(a.shape, (3,))
217
216
 
218
217
  def test_standard_normal(self):
219
- bst.random.seed()
220
- a = bst.random.standard_normal(size=(3, 2))
218
+ brainstate.random.seed()
219
+ a = brainstate.random.standard_normal(size=(3, 2))
221
220
  self.assertTupleEqual(a.shape, (3, 2))
222
221
 
223
222
  def test_standard_t(self):
224
- bst.random.seed()
225
- a = bst.random.standard_t(df=[1, 2, 4], size=3)
223
+ brainstate.random.seed()
224
+ a = brainstate.random.standard_t(df=[1, 2, 4], size=3)
226
225
  self.assertTupleEqual(a.shape, (3,))
227
226
 
228
227
  def test_standard_uniform1(self):
229
- bst.random.seed()
230
- a = bst.random.uniform()
228
+ brainstate.random.seed()
229
+ a = brainstate.random.uniform()
231
230
  self.assertTupleEqual(a.shape, ())
232
231
  self.assertTrue(0 <= a < 1)
233
232
 
234
233
  def test_uniform2(self):
235
- bst.random.seed()
236
- a = bst.random.uniform(low=[-1., 5., 2.], high=[2., 6., 10.], size=3)
234
+ brainstate.random.seed()
235
+ a = brainstate.random.uniform(low=[-1., 5., 2.], high=[2., 6., 10.], size=3)
237
236
  self.assertTupleEqual(a.shape, (3,))
238
237
  self.assertTrue((a - jnp.array([-1., 5., 2.]) >= 0).all()
239
238
  and (-a + jnp.array([2., 6., 10.]) > 0).all())
240
239
 
241
240
  def test_uniform3(self):
242
- bst.random.seed()
243
- a = bst.random.uniform(low=-1., high=[2., 6., 10.], size=(2, 3))
241
+ brainstate.random.seed()
242
+ a = brainstate.random.uniform(low=-1., high=[2., 6., 10.], size=(2, 3))
244
243
  self.assertTupleEqual(a.shape, (2, 3))
245
244
 
246
245
  def test_uniform4(self):
247
- bst.random.seed()
248
- a = bst.random.uniform(low=[-1., 5., 2.], high=[[2., 6., 10.], [10., 10., 10.]])
246
+ brainstate.random.seed()
247
+ a = brainstate.random.uniform(low=[-1., 5., 2.], high=[[2., 6., 10.], [10., 10., 10.]])
249
248
  self.assertTupleEqual(a.shape, (2, 3))
250
249
 
251
250
  def test_truncated_normal1(self):
252
- bst.random.seed()
253
- a = bst.random.truncated_normal(-1., 1.)
251
+ brainstate.random.seed()
252
+ a = brainstate.random.truncated_normal(-1., 1.)
254
253
  self.assertTupleEqual(a.shape, ())
255
254
  self.assertTrue(-1. <= a <= 1.)
256
255
 
257
256
  def test_truncated_normal2(self):
258
- bst.random.seed()
259
- a = bst.random.truncated_normal(-1., [1., 2., 1.], size=(4, 3))
257
+ brainstate.random.seed()
258
+ a = brainstate.random.truncated_normal(-1., [1., 2., 1.], size=(4, 3))
260
259
  self.assertTupleEqual(a.shape, (4, 3))
261
260
 
262
261
  def test_truncated_normal3(self):
263
- bst.random.seed()
264
- a = bst.random.truncated_normal([-1., 0., 1.], [[2., 2., 4.], [2., 2., 4.]])
262
+ brainstate.random.seed()
263
+ a = brainstate.random.truncated_normal([-1., 0., 1.], [[2., 2., 4.], [2., 2., 4.]])
265
264
  self.assertTupleEqual(a.shape, (2, 3))
266
265
  self.assertTrue((a - jnp.array([-1., 0., 1.]) >= 0.).all()
267
266
  and (- a + jnp.array([2., 2., 4.]) >= 0.).all())
268
267
 
269
268
  def test_bernoulli1(self):
270
- bst.random.seed()
271
- a = bst.random.bernoulli()
269
+ brainstate.random.seed()
270
+ a = brainstate.random.bernoulli()
272
271
  self.assertTupleEqual(a.shape, ())
273
272
  self.assertTrue(a == 0 or a == 1)
274
273
 
275
274
  def test_bernoulli2(self):
276
- bst.random.seed()
277
- a = bst.random.bernoulli([0.5, 0.6, 0.8])
275
+ brainstate.random.seed()
276
+ a = brainstate.random.bernoulli([0.5, 0.6, 0.8])
278
277
  self.assertTupleEqual(a.shape, (3,))
279
278
  self.assertTrue(jnp.logical_xor(a == 1, a == 0).all())
280
279
 
281
280
  def test_bernoulli3(self):
282
- bst.random.seed()
283
- a = bst.random.bernoulli([0.5, 0.6], size=(3, 2))
281
+ brainstate.random.seed()
282
+ a = brainstate.random.bernoulli([0.5, 0.6], size=(3, 2))
284
283
  self.assertTupleEqual(a.shape, (3, 2))
285
284
  self.assertTrue(jnp.logical_xor(a == 1, a == 0).all())
286
285
 
287
286
  def test_lognormal1(self):
288
- bst.random.seed()
289
- a = bst.random.lognormal()
287
+ brainstate.random.seed()
288
+ a = brainstate.random.lognormal()
290
289
  self.assertTupleEqual(a.shape, ())
291
290
 
292
291
  def test_lognormal2(self):
293
- bst.random.seed()
294
- a = bst.random.lognormal(sigma=[2., 1.], size=[3, 2])
292
+ brainstate.random.seed()
293
+ a = brainstate.random.lognormal(sigma=[2., 1.], size=[3, 2])
295
294
  self.assertTupleEqual(a.shape, (3, 2))
296
295
 
297
296
  def test_lognormal3(self):
298
- bst.random.seed()
299
- a = bst.random.lognormal([2., 0.], [[2., 1.], [3., 1.2]])
297
+ brainstate.random.seed()
298
+ a = brainstate.random.lognormal([2., 0.], [[2., 1.], [3., 1.2]])
300
299
  self.assertTupleEqual(a.shape, (2, 2))
301
300
 
302
301
  def test_binomial1(self):
303
- bst.random.seed()
304
- a = bst.random.binomial(5, 0.5)
302
+ brainstate.random.seed()
303
+ a = brainstate.random.binomial(5, 0.5)
305
304
  b = np.random.binomial(5, 0.5)
306
305
  print(a)
307
306
  print(b)
@@ -309,99 +308,99 @@ class TestRandom(unittest.TestCase):
309
308
  self.assertTrue(a.dtype, int)
310
309
 
311
310
  def test_binomial2(self):
312
- bst.random.seed()
313
- a = bst.random.binomial(5, 0.5, size=(3, 2))
311
+ brainstate.random.seed()
312
+ a = brainstate.random.binomial(5, 0.5, size=(3, 2))
314
313
  self.assertTupleEqual(a.shape, (3, 2))
315
314
  self.assertTrue((a >= 0).all() and (a <= 5).all())
316
315
 
317
316
  def test_binomial3(self):
318
- bst.random.seed()
319
- a = bst.random.binomial(n=jnp.asarray([2, 3, 4]), p=jnp.asarray([[0.5, 0.5, 0.5], [0.6, 0.6, 0.6]]))
317
+ brainstate.random.seed()
318
+ a = brainstate.random.binomial(n=jnp.asarray([2, 3, 4]), p=jnp.asarray([[0.5, 0.5, 0.5], [0.6, 0.6, 0.6]]))
320
319
  self.assertTupleEqual(a.shape, (2, 3))
321
320
 
322
321
  def test_chisquare1(self):
323
- bst.random.seed()
324
- a = bst.random.chisquare(3)
322
+ brainstate.random.seed()
323
+ a = brainstate.random.chisquare(3)
325
324
  self.assertTupleEqual(a.shape, ())
326
325
  self.assertTrue(a.dtype, float)
327
326
 
328
327
  def test_chisquare2(self):
329
- bst.random.seed()
328
+ brainstate.random.seed()
330
329
  with self.assertRaises(NotImplementedError):
331
- a = bst.random.chisquare(df=[2, 3, 4])
330
+ a = brainstate.random.chisquare(df=[2, 3, 4])
332
331
 
333
332
  def test_chisquare3(self):
334
- bst.random.seed()
335
- a = bst.random.chisquare(df=2, size=100)
333
+ brainstate.random.seed()
334
+ a = brainstate.random.chisquare(df=2, size=100)
336
335
  self.assertTupleEqual(a.shape, (100,))
337
336
 
338
337
  def test_chisquare4(self):
339
- bst.random.seed()
340
- a = bst.random.chisquare(df=2, size=(100, 10))
338
+ brainstate.random.seed()
339
+ a = brainstate.random.chisquare(df=2, size=(100, 10))
341
340
  self.assertTupleEqual(a.shape, (100, 10))
342
341
 
343
342
  def test_dirichlet1(self):
344
- bst.random.seed()
345
- a = bst.random.dirichlet((10, 5, 3))
343
+ brainstate.random.seed()
344
+ a = brainstate.random.dirichlet((10, 5, 3))
346
345
  self.assertTupleEqual(a.shape, (3,))
347
346
 
348
347
  def test_dirichlet2(self):
349
- bst.random.seed()
350
- a = bst.random.dirichlet((10, 5, 3), 20)
348
+ brainstate.random.seed()
349
+ a = brainstate.random.dirichlet((10, 5, 3), 20)
351
350
  self.assertTupleEqual(a.shape, (20, 3))
352
351
 
353
352
  def test_f(self):
354
- bst.random.seed()
355
- a = bst.random.f(1., 48., 100)
353
+ brainstate.random.seed()
354
+ a = brainstate.random.f(1., 48., 100)
356
355
  self.assertTupleEqual(a.shape, (100,))
357
356
 
358
357
  def test_geometric(self):
359
- bst.random.seed()
360
- a = bst.random.geometric([0.7, 0.5, 0.2])
358
+ brainstate.random.seed()
359
+ a = brainstate.random.geometric([0.7, 0.5, 0.2])
361
360
  self.assertTupleEqual(a.shape, (3,))
362
361
 
363
362
  def test_hypergeometric1(self):
364
- bst.random.seed()
365
- a = bst.random.hypergeometric(10, 10, 10, 20)
363
+ brainstate.random.seed()
364
+ a = brainstate.random.hypergeometric(10, 10, 10, 20)
366
365
  self.assertTupleEqual(a.shape, (20,))
367
366
 
368
367
  @pytest.mark.skipif(platform.system() == 'Windows', reason='Windows jaxlib error')
369
368
  def test_hypergeometric2(self):
370
- bst.random.seed()
371
- a = bst.random.hypergeometric(8, [10, 4], [[5, 2], [5, 5]])
369
+ brainstate.random.seed()
370
+ a = brainstate.random.hypergeometric(8, [10, 4], [[5, 2], [5, 5]])
372
371
  self.assertTupleEqual(a.shape, (2, 2))
373
372
 
374
373
  @pytest.mark.skipif(platform.system() == 'Windows', reason='Windows jaxlib error')
375
374
  def test_hypergeometric3(self):
376
- bst.random.seed()
377
- a = bst.random.hypergeometric(8, [10, 4], [[5, 2], [5, 5]], size=(3, 2, 2))
375
+ brainstate.random.seed()
376
+ a = brainstate.random.hypergeometric(8, [10, 4], [[5, 2], [5, 5]], size=(3, 2, 2))
378
377
  self.assertTupleEqual(a.shape, (3, 2, 2))
379
378
 
380
379
  def test_logseries(self):
381
- bst.random.seed()
382
- a = bst.random.logseries([0.7, 0.5, 0.2], size=[4, 3])
380
+ brainstate.random.seed()
381
+ a = brainstate.random.logseries([0.7, 0.5, 0.2], size=[4, 3])
383
382
  self.assertTupleEqual(a.shape, (4, 3))
384
383
 
385
384
  def test_multinominal1(self):
386
- bst.random.seed()
385
+ brainstate.random.seed()
387
386
  a = np.random.multinomial(100, (0.5, 0.2, 0.3), size=[4, 2])
388
387
  print(a, a.shape)
389
- b = bst.random.multinomial(100, (0.5, 0.2, 0.3), size=[4, 2])
388
+ b = brainstate.random.multinomial(100, (0.5, 0.2, 0.3), size=[4, 2])
390
389
  print(b, b.shape)
391
390
  self.assertTupleEqual(a.shape, b.shape)
392
391
  self.assertTupleEqual(b.shape, (4, 2, 3))
393
392
 
394
393
  def test_multinominal2(self):
395
- bst.random.seed()
396
- a = bst.random.multinomial(100, (0.5, 0.2, 0.3))
394
+ brainstate.random.seed()
395
+ a = brainstate.random.multinomial(100, (0.5, 0.2, 0.3))
397
396
  self.assertTupleEqual(a.shape, (3,))
398
397
  self.assertTrue(a.sum() == 100)
399
398
 
400
399
  def test_multivariate_normal1(self):
401
- bst.random.seed()
400
+ brainstate.random.seed()
402
401
  # self.skipTest('Windows jaxlib error')
403
402
  a = np.random.multivariate_normal([1, 2], [[1, 0], [0, 1]], size=3)
404
- b = bst.random.multivariate_normal([1, 2], [[1, 0], [0, 1]], size=3)
403
+ b = brainstate.random.multivariate_normal([1, 2], [[1, 0], [0, 1]], size=3)
405
404
  print('test_multivariate_normal1')
406
405
  print(a)
407
406
  print(b)
@@ -409,156 +408,156 @@ class TestRandom(unittest.TestCase):
409
408
  self.assertTupleEqual(a.shape, (3, 2))
410
409
 
411
410
  def test_multivariate_normal2(self):
412
- bst.random.seed()
411
+ brainstate.random.seed()
413
412
  a = np.random.multivariate_normal([1, 2], [[1, 3], [3, 1]])
414
- b = bst.random.multivariate_normal([1, 2], [[1, 3], [3, 1]], method='svd')
413
+ b = brainstate.random.multivariate_normal([1, 2], [[1, 3], [3, 1]], method='svd')
415
414
  print(a)
416
415
  print(b)
417
416
  self.assertTupleEqual(a.shape, b.shape)
418
417
  self.assertTupleEqual(a.shape, (2,))
419
418
 
420
419
  def test_negative_binomial(self):
421
- bst.random.seed()
420
+ brainstate.random.seed()
422
421
  a = np.random.negative_binomial([3., 10.], 0.5)
423
- b = bst.random.negative_binomial([3., 10.], 0.5)
422
+ b = brainstate.random.negative_binomial([3., 10.], 0.5)
424
423
  print(a)
425
424
  print(b)
426
425
  self.assertTupleEqual(a.shape, b.shape)
427
426
  self.assertTupleEqual(b.shape, (2,))
428
427
 
429
428
  def test_negative_binomial2(self):
430
- bst.random.seed()
429
+ brainstate.random.seed()
431
430
  a = np.random.negative_binomial(3., 0.5, 10)
432
- b = bst.random.negative_binomial(3., 0.5, 10)
431
+ b = brainstate.random.negative_binomial(3., 0.5, 10)
433
432
  print(a)
434
433
  print(b)
435
434
  self.assertTupleEqual(a.shape, b.shape)
436
435
  self.assertTupleEqual(b.shape, (10,))
437
436
 
438
437
  def test_noncentral_chisquare(self):
439
- bst.random.seed()
438
+ brainstate.random.seed()
440
439
  a = np.random.noncentral_chisquare(3, [3., 2.], (4, 2))
441
- b = bst.random.noncentral_chisquare(3, [3., 2.], (4, 2))
440
+ b = brainstate.random.noncentral_chisquare(3, [3., 2.], (4, 2))
442
441
  self.assertTupleEqual(a.shape, b.shape)
443
442
  self.assertTupleEqual(b.shape, (4, 2))
444
443
 
445
444
  def test_noncentral_chisquare2(self):
446
- bst.random.seed()
447
- a = bst.random.noncentral_chisquare(3, [3., 2.])
445
+ brainstate.random.seed()
446
+ a = brainstate.random.noncentral_chisquare(3, [3., 2.])
448
447
  self.assertTupleEqual(a.shape, (2,))
449
448
 
450
449
  def test_noncentral_f(self):
451
- bst.random.seed()
452
- a = bst.random.noncentral_f(3, 20, 3., 100)
450
+ brainstate.random.seed()
451
+ a = brainstate.random.noncentral_f(3, 20, 3., 100)
453
452
  self.assertTupleEqual(a.shape, (100,))
454
453
 
455
454
  def test_power(self):
456
- bst.random.seed()
455
+ brainstate.random.seed()
457
456
  a = np.random.power(2, (4, 2))
458
- b = bst.random.power(2, (4, 2))
457
+ b = brainstate.random.power(2, (4, 2))
459
458
  self.assertTupleEqual(a.shape, b.shape)
460
459
  self.assertTupleEqual(b.shape, (4, 2))
461
460
 
462
461
  def test_rayleigh(self):
463
- bst.random.seed()
464
- a = bst.random.power(2., (4, 2))
462
+ brainstate.random.seed()
463
+ a = brainstate.random.power(2., (4, 2))
465
464
  self.assertTupleEqual(a.shape, (4, 2))
466
465
 
467
466
  def test_triangular(self):
468
- bst.random.seed()
469
- a = bst.random.triangular((2, 2))
467
+ brainstate.random.seed()
468
+ a = brainstate.random.triangular((2, 2))
470
469
  self.assertTupleEqual(a.shape, (2, 2))
471
470
 
472
471
  def test_vonmises(self):
473
- bst.random.seed()
472
+ brainstate.random.seed()
474
473
  a = np.random.vonmises(2., 2.)
475
- b = bst.random.vonmises(2., 2.)
474
+ b = brainstate.random.vonmises(2., 2.)
476
475
  print(a, b)
477
476
  self.assertTupleEqual(np.shape(a), b.shape)
478
477
  self.assertTupleEqual(b.shape, ())
479
478
 
480
479
  def test_vonmises2(self):
481
- bst.random.seed()
480
+ brainstate.random.seed()
482
481
  a = np.random.vonmises(2., 2., 10)
483
- b = bst.random.vonmises(2., 2., 10)
482
+ b = brainstate.random.vonmises(2., 2., 10)
484
483
  print(a, b)
485
484
  self.assertTupleEqual(a.shape, b.shape)
486
485
  self.assertTupleEqual(b.shape, (10,))
487
486
 
488
487
  def test_wald(self):
489
- bst.random.seed()
488
+ brainstate.random.seed()
490
489
  a = np.random.wald([2., 0.5], 2.)
491
- b = bst.random.wald([2., 0.5], 2.)
490
+ b = brainstate.random.wald([2., 0.5], 2.)
492
491
  self.assertTupleEqual(a.shape, b.shape)
493
492
  self.assertTupleEqual(b.shape, (2,))
494
493
 
495
494
  def test_wald2(self):
496
- bst.random.seed()
495
+ brainstate.random.seed()
497
496
  a = np.random.wald(2., 2., 100)
498
- b = bst.random.wald(2., 2., 100)
497
+ b = brainstate.random.wald(2., 2., 100)
499
498
  self.assertTupleEqual(a.shape, b.shape)
500
499
  self.assertTupleEqual(b.shape, (100,))
501
500
 
502
501
  def test_weibull(self):
503
- bst.random.seed()
504
- a = bst.random.weibull(2., (4, 2))
502
+ brainstate.random.seed()
503
+ a = brainstate.random.weibull(2., (4, 2))
505
504
  self.assertTupleEqual(a.shape, (4, 2))
506
505
 
507
506
  def test_weibull2(self):
508
- bst.random.seed()
509
- a = bst.random.weibull(2., )
507
+ brainstate.random.seed()
508
+ a = brainstate.random.weibull(2., )
510
509
  self.assertTupleEqual(a.shape, ())
511
510
 
512
511
  def test_weibull3(self):
513
- bst.random.seed()
514
- a = bst.random.weibull([2., 3.], )
512
+ brainstate.random.seed()
513
+ a = brainstate.random.weibull([2., 3.], )
515
514
  self.assertTupleEqual(a.shape, (2,))
516
515
 
517
516
  def test_weibull_min(self):
518
- bst.random.seed()
519
- a = bst.random.weibull_min(2., 2., (4, 2))
517
+ brainstate.random.seed()
518
+ a = brainstate.random.weibull_min(2., 2., (4, 2))
520
519
  self.assertTupleEqual(a.shape, (4, 2))
521
520
 
522
521
  def test_weibull_min2(self):
523
- bst.random.seed()
524
- a = bst.random.weibull_min(2., 2.)
522
+ brainstate.random.seed()
523
+ a = brainstate.random.weibull_min(2., 2.)
525
524
  self.assertTupleEqual(a.shape, ())
526
525
 
527
526
  def test_weibull_min3(self):
528
- bst.random.seed()
529
- a = bst.random.weibull_min([2., 3.], 2.)
527
+ brainstate.random.seed()
528
+ a = brainstate.random.weibull_min([2., 3.], 2.)
530
529
  self.assertTupleEqual(a.shape, (2,))
531
530
 
532
531
  def test_zipf(self):
533
- bst.random.seed()
534
- a = bst.random.zipf(2., (4, 2))
532
+ brainstate.random.seed()
533
+ a = brainstate.random.zipf(2., (4, 2))
535
534
  self.assertTupleEqual(a.shape, (4, 2))
536
535
 
537
536
  def test_zipf2(self):
538
- bst.random.seed()
537
+ brainstate.random.seed()
539
538
  a = np.random.zipf([1.1, 2.])
540
- b = bst.random.zipf([1.1, 2.])
539
+ b = brainstate.random.zipf([1.1, 2.])
541
540
  self.assertTupleEqual(a.shape, b.shape)
542
541
  self.assertTupleEqual(b.shape, (2,))
543
542
 
544
543
  def test_maxwell(self):
545
- bst.random.seed()
546
- a = bst.random.maxwell(10)
544
+ brainstate.random.seed()
545
+ a = brainstate.random.maxwell(10)
547
546
  self.assertTupleEqual(a.shape, (10,))
548
547
 
549
548
  def test_maxwell2(self):
550
- bst.random.seed()
551
- a = bst.random.maxwell()
549
+ brainstate.random.seed()
550
+ a = brainstate.random.maxwell()
552
551
  self.assertTupleEqual(a.shape, ())
553
552
 
554
553
  def test_t(self):
555
- bst.random.seed()
556
- a = bst.random.t(1., size=10)
554
+ brainstate.random.seed()
555
+ a = brainstate.random.t(1., size=10)
557
556
  self.assertTupleEqual(a.shape, (10,))
558
557
 
559
558
  def test_t2(self):
560
- bst.random.seed()
561
- a = bst.random.t([1., 2.], size=None)
559
+ brainstate.random.seed()
560
+ a = brainstate.random.t([1., 2.], size=None)
562
561
  self.assertTupleEqual(a.shape, (2,))
563
562
 
564
563
  # class TestRandomKey(unittest.TestCase):