brainstate 0.1.10__py2.py3-none-any.whl → 0.2.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (163) hide show
  1. brainstate/__init__.py +130 -19
  2. brainstate/_compatible_import.py +201 -9
  3. brainstate/_compatible_import_test.py +681 -0
  4. brainstate/_deprecation.py +210 -0
  5. brainstate/_deprecation_test.py +2319 -0
  6. brainstate/{util/error.py → _error.py} +10 -20
  7. brainstate/_state.py +94 -47
  8. brainstate/_state_test.py +1 -1
  9. brainstate/_utils.py +1 -1
  10. brainstate/environ.py +1279 -347
  11. brainstate/environ_test.py +1187 -26
  12. brainstate/graph/__init__.py +6 -13
  13. brainstate/graph/_node.py +240 -0
  14. brainstate/graph/_node_test.py +589 -0
  15. brainstate/graph/{_graph_operation.py → _operation.py} +632 -746
  16. brainstate/graph/_operation_test.py +1147 -0
  17. brainstate/mixin.py +1209 -141
  18. brainstate/mixin_test.py +991 -51
  19. brainstate/nn/__init__.py +74 -72
  20. brainstate/nn/_activations.py +587 -295
  21. brainstate/nn/_activations_test.py +109 -86
  22. brainstate/nn/_collective_ops.py +393 -274
  23. brainstate/nn/_collective_ops_test.py +746 -15
  24. brainstate/nn/_common.py +114 -66
  25. brainstate/nn/_common_test.py +154 -0
  26. brainstate/nn/_conv.py +1652 -143
  27. brainstate/nn/_conv_test.py +838 -227
  28. brainstate/nn/_delay.py +15 -28
  29. brainstate/nn/_delay_test.py +25 -20
  30. brainstate/nn/_dropout.py +359 -167
  31. brainstate/nn/_dropout_test.py +429 -52
  32. brainstate/nn/_dynamics.py +14 -90
  33. brainstate/nn/_dynamics_test.py +1 -12
  34. brainstate/nn/_elementwise.py +492 -313
  35. brainstate/nn/_elementwise_test.py +806 -145
  36. brainstate/nn/_embedding.py +369 -19
  37. brainstate/nn/_embedding_test.py +156 -0
  38. brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +10 -16
  39. brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +6 -5
  40. brainstate/nn/{_linear_mv.py → _event_linear.py} +2 -2
  41. brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +6 -5
  42. brainstate/nn/_exp_euler.py +200 -38
  43. brainstate/nn/_exp_euler_test.py +350 -8
  44. brainstate/nn/_linear.py +391 -71
  45. brainstate/nn/_linear_test.py +427 -59
  46. brainstate/nn/_metrics.py +1070 -0
  47. brainstate/nn/_metrics_test.py +611 -0
  48. brainstate/nn/_module.py +10 -3
  49. brainstate/nn/_module_test.py +1 -1
  50. brainstate/nn/_normalizations.py +688 -329
  51. brainstate/nn/_normalizations_test.py +663 -37
  52. brainstate/nn/_paddings.py +1020 -0
  53. brainstate/nn/_paddings_test.py +723 -0
  54. brainstate/nn/_poolings.py +1404 -342
  55. brainstate/nn/_poolings_test.py +828 -92
  56. brainstate/nn/{_rate_rnns.py → _rnns.py} +446 -54
  57. brainstate/nn/_rnns_test.py +593 -0
  58. brainstate/nn/_utils.py +132 -5
  59. brainstate/nn/_utils_test.py +402 -0
  60. brainstate/{init/_random_inits.py → nn/init.py} +301 -45
  61. brainstate/{init/_random_inits_test.py → nn/init_test.py} +51 -20
  62. brainstate/random/__init__.py +247 -1
  63. brainstate/random/_rand_funs.py +668 -346
  64. brainstate/random/_rand_funs_test.py +74 -1
  65. brainstate/random/_rand_seed.py +541 -76
  66. brainstate/random/_rand_seed_test.py +1 -1
  67. brainstate/random/_rand_state.py +601 -393
  68. brainstate/random/_rand_state_test.py +551 -0
  69. brainstate/transform/__init__.py +59 -0
  70. brainstate/transform/_ad_checkpoint.py +176 -0
  71. brainstate/{compile → transform}/_ad_checkpoint_test.py +1 -1
  72. brainstate/{augment → transform}/_autograd.py +360 -113
  73. brainstate/{augment → transform}/_autograd_test.py +2 -2
  74. brainstate/transform/_conditions.py +316 -0
  75. brainstate/{compile → transform}/_conditions_test.py +11 -11
  76. brainstate/{compile → transform}/_error_if.py +22 -20
  77. brainstate/{compile → transform}/_error_if_test.py +1 -1
  78. brainstate/transform/_eval_shape.py +145 -0
  79. brainstate/{augment → transform}/_eval_shape_test.py +1 -1
  80. brainstate/{compile → transform}/_jit.py +99 -46
  81. brainstate/{compile → transform}/_jit_test.py +3 -3
  82. brainstate/{compile → transform}/_loop_collect_return.py +219 -80
  83. brainstate/{compile → transform}/_loop_collect_return_test.py +1 -1
  84. brainstate/{compile → transform}/_loop_no_collection.py +133 -34
  85. brainstate/{compile → transform}/_loop_no_collection_test.py +2 -2
  86. brainstate/transform/_make_jaxpr.py +2016 -0
  87. brainstate/transform/_make_jaxpr_test.py +1510 -0
  88. brainstate/transform/_mapping.py +529 -0
  89. brainstate/transform/_mapping_test.py +194 -0
  90. brainstate/{compile → transform}/_progress_bar.py +78 -25
  91. brainstate/{augment → transform}/_random.py +65 -45
  92. brainstate/{compile → transform}/_unvmap.py +102 -5
  93. brainstate/transform/_util.py +286 -0
  94. brainstate/typing.py +594 -61
  95. brainstate/typing_test.py +780 -0
  96. brainstate/util/__init__.py +9 -32
  97. brainstate/util/_others.py +1025 -0
  98. brainstate/util/_others_test.py +962 -0
  99. brainstate/util/_pretty_pytree.py +1301 -0
  100. brainstate/util/_pretty_pytree_test.py +675 -0
  101. brainstate/util/{pretty_repr.py → _pretty_repr.py} +161 -27
  102. brainstate/util/_pretty_repr_test.py +696 -0
  103. brainstate/util/filter.py +557 -81
  104. brainstate/util/filter_test.py +912 -0
  105. brainstate/util/struct.py +769 -382
  106. brainstate/util/struct_test.py +602 -0
  107. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/METADATA +34 -17
  108. brainstate-0.2.0.dist-info/RECORD +111 -0
  109. brainstate/augment/__init__.py +0 -30
  110. brainstate/augment/_eval_shape.py +0 -99
  111. brainstate/augment/_mapping.py +0 -1060
  112. brainstate/augment/_mapping_test.py +0 -597
  113. brainstate/compile/__init__.py +0 -38
  114. brainstate/compile/_ad_checkpoint.py +0 -204
  115. brainstate/compile/_conditions.py +0 -256
  116. brainstate/compile/_make_jaxpr.py +0 -888
  117. brainstate/compile/_make_jaxpr_test.py +0 -156
  118. brainstate/compile/_util.py +0 -147
  119. brainstate/functional/__init__.py +0 -27
  120. brainstate/graph/_graph_node.py +0 -244
  121. brainstate/graph/_graph_node_test.py +0 -73
  122. brainstate/graph/_graph_operation_test.py +0 -563
  123. brainstate/init/__init__.py +0 -26
  124. brainstate/init/_base.py +0 -52
  125. brainstate/init/_generic.py +0 -244
  126. brainstate/init/_regular_inits.py +0 -105
  127. brainstate/init/_regular_inits_test.py +0 -50
  128. brainstate/nn/_inputs.py +0 -608
  129. brainstate/nn/_ltp.py +0 -28
  130. brainstate/nn/_neuron.py +0 -705
  131. brainstate/nn/_neuron_test.py +0 -161
  132. brainstate/nn/_others.py +0 -46
  133. brainstate/nn/_projection.py +0 -486
  134. brainstate/nn/_rate_rnns_test.py +0 -63
  135. brainstate/nn/_readout.py +0 -209
  136. brainstate/nn/_readout_test.py +0 -53
  137. brainstate/nn/_stp.py +0 -236
  138. brainstate/nn/_synapse.py +0 -505
  139. brainstate/nn/_synapse_test.py +0 -131
  140. brainstate/nn/_synaptic_projection.py +0 -423
  141. brainstate/nn/_synouts.py +0 -162
  142. brainstate/nn/_synouts_test.py +0 -57
  143. brainstate/nn/metrics.py +0 -388
  144. brainstate/optim/__init__.py +0 -38
  145. brainstate/optim/_base.py +0 -64
  146. brainstate/optim/_lr_scheduler.py +0 -448
  147. brainstate/optim/_lr_scheduler_test.py +0 -50
  148. brainstate/optim/_optax_optimizer.py +0 -152
  149. brainstate/optim/_optax_optimizer_test.py +0 -53
  150. brainstate/optim/_sgd_optimizer.py +0 -1104
  151. brainstate/random/_random_for_unit.py +0 -52
  152. brainstate/surrogate.py +0 -1957
  153. brainstate/transform.py +0 -23
  154. brainstate/util/caller.py +0 -98
  155. brainstate/util/others.py +0 -540
  156. brainstate/util/pretty_pytree.py +0 -945
  157. brainstate/util/pretty_pytree_test.py +0 -159
  158. brainstate/util/pretty_table.py +0 -2954
  159. brainstate/util/scaling.py +0 -258
  160. brainstate-0.1.10.dist-info/RECORD +0 -130
  161. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/WHEEL +0 -0
  162. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/licenses/LICENSE +0 -0
  163. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/top_level.txt +0 -0
@@ -1,4 +1,4 @@
1
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
1
+ # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -20,142 +20,162 @@ from functools import partial
20
20
 
21
21
  import jax
22
22
  import jax.numpy as jnp
23
+ import numpy as np
23
24
  import scipy.stats
24
- from absl.testing import parameterized
25
- from jax._src import test_util as jtu
25
+ from absl.testing import absltest, parameterized
26
26
  from jax.test_util import check_grads
27
27
 
28
28
  import brainstate
29
29
 
30
30
 
31
- class NNFunctionsTest(jtu.JaxTestCase):
32
- @jtu.skip_on_flag("jax_skip_slow_tests", True)
31
+ class NNFunctionsTest(parameterized.TestCase):
32
+ def setUp(self):
33
+ super().setUp()
34
+ self.rng_key = jax.random.PRNGKey(0)
35
+
36
+ def assertAllClose(self, a, b, check_dtypes=True, atol=None, rtol=None):
37
+ """Helper method for backwards compatibility with JAX test utilities."""
38
+ a = np.asarray(a)
39
+ b = np.asarray(b)
40
+ kw = {}
41
+ if atol is not None:
42
+ kw['atol'] = atol
43
+ if rtol is not None:
44
+ kw['rtol'] = rtol
45
+ np.testing.assert_allclose(a, b, **kw)
46
+ if check_dtypes:
47
+ self.assertEqual(a.dtype, b.dtype)
48
+
49
+ def assertArraysEqual(self, a, b):
50
+ """Helper method for backwards compatibility with JAX test utilities."""
51
+ np.testing.assert_array_equal(np.asarray(a), np.asarray(b))
52
+
33
53
  def testSoftplusGrad(self):
34
- check_grads(brainstate.functional.softplus, (1e-8,), order=4, )
54
+ check_grads(brainstate.nn.softplus, (1e-8,), order=4, )
35
55
 
36
56
  def testSoftplusGradZero(self):
37
- check_grads(brainstate.functional.softplus, (0.,), order=1)
57
+ check_grads(brainstate.nn.softplus, (0.,), order=1)
38
58
 
39
59
  def testSoftplusGradInf(self):
40
- self.assertAllClose(1., jax.grad(brainstate.functional.softplus)(float('inf')))
60
+ self.assertAllClose(1., jax.grad(brainstate.nn.softplus)(float('inf')), check_dtypes=False)
41
61
 
42
62
  def testSoftplusGradNegInf(self):
43
- check_grads(brainstate.functional.softplus, (-float('inf'),), order=1)
63
+ check_grads(brainstate.nn.softplus, (-float('inf'),), order=1)
44
64
 
45
65
  def testSoftplusGradNan(self):
46
- check_grads(brainstate.functional.softplus, (float('nan'),), order=1)
66
+ check_grads(brainstate.nn.softplus, (float('nan'),), order=1)
47
67
 
48
- @parameterized.parameters([int, float] + jtu.dtypes.floating + jtu.dtypes.integer)
68
+ @parameterized.parameters([int, float, jnp.float32, jnp.float64, jnp.int32, jnp.int64])
49
69
  def testSoftplusZero(self, dtype):
50
- self.assertEqual(jnp.log(dtype(2)), brainstate.functional.softplus(dtype(0)))
70
+ self.assertEqual(jnp.log(dtype(2)), brainstate.nn.softplus(dtype(0)))
51
71
 
52
72
  def testSparseplusGradZero(self):
53
- check_grads(brainstate.functional.sparse_plus, (-2.,), order=1)
73
+ check_grads(brainstate.nn.sparse_plus, (-2.,), order=1)
54
74
 
55
75
  def testSparseplusGrad(self):
56
- check_grads(brainstate.functional.sparse_plus, (0.,), order=1)
76
+ check_grads(brainstate.nn.sparse_plus, (0.,), order=1)
57
77
 
58
78
  def testSparseplusAndSparseSigmoid(self):
59
79
  self.assertAllClose(
60
- jax.grad(brainstate.functional.sparse_plus)(0.),
61
- brainstate.functional.sparse_sigmoid(0.),
80
+ jax.grad(brainstate.nn.sparse_plus)(0.),
81
+ brainstate.nn.sparse_sigmoid(0.),
62
82
  check_dtypes=False)
63
83
  self.assertAllClose(
64
- jax.grad(brainstate.functional.sparse_plus)(2.),
65
- brainstate.functional.sparse_sigmoid(2.),
84
+ jax.grad(brainstate.nn.sparse_plus)(2.),
85
+ brainstate.nn.sparse_sigmoid(2.),
66
86
  check_dtypes=False)
67
87
  self.assertAllClose(
68
- jax.grad(brainstate.functional.sparse_plus)(-2.),
69
- brainstate.functional.sparse_sigmoid(-2.),
88
+ jax.grad(brainstate.nn.sparse_plus)(-2.),
89
+ brainstate.nn.sparse_sigmoid(-2.),
70
90
  check_dtypes=False)
71
91
 
72
92
  # def testSquareplusGrad(self):
73
- # check_grads(brainstate.functional.squareplus, (1e-8,), order=4,
93
+ # check_grads(brainstate.nn.squareplus, (1e-8,), order=4,
74
94
  # )
75
95
 
76
96
  # def testSquareplusGradZero(self):
77
- # check_grads(brainstate.functional.squareplus, (0.,), order=1,
97
+ # check_grads(brainstate.nn.squareplus, (0.,), order=1,
78
98
  # )
79
99
 
80
100
  # def testSquareplusGradNegInf(self):
81
- # check_grads(brainstate.functional.squareplus, (-float('inf'),), order=1,
101
+ # check_grads(brainstate.nn.squareplus, (-float('inf'),), order=1,
82
102
  # )
83
103
 
84
104
  # def testSquareplusGradNan(self):
85
- # check_grads(brainstate.functional.squareplus, (float('nan'),), order=1,
105
+ # check_grads(brainstate.nn.squareplus, (float('nan'),), order=1,
86
106
  # )
87
107
 
88
- # @parameterized.parameters([float] + jtu.dtypes.floating)
108
+ # @parameterized.parameters([float, jnp.float32, jnp.float64])
89
109
  # def testSquareplusZero(self, dtype):
90
- # self.assertEqual(dtype(1), brainstate.functional.squareplus(dtype(0), dtype(4)))
110
+ # self.assertEqual(dtype(1), brainstate.nn.squareplus(dtype(0), dtype(4)))
91
111
  #
92
112
  # def testMishGrad(self):
93
- # check_grads(brainstate.functional.mish, (1e-8,), order=4,
113
+ # check_grads(brainstate.nn.mish, (1e-8,), order=4,
94
114
  # )
95
115
  #
96
116
  # def testMishGradZero(self):
97
- # check_grads(brainstate.functional.mish, (0.,), order=1,
117
+ # check_grads(brainstate.nn.mish, (0.,), order=1,
98
118
  # )
99
119
  #
100
120
  # def testMishGradNegInf(self):
101
- # check_grads(brainstate.functional.mish, (-float('inf'),), order=1,
121
+ # check_grads(brainstate.nn.mish, (-float('inf'),), order=1,
102
122
  # )
103
123
  #
104
124
  # def testMishGradNan(self):
105
- # check_grads(brainstate.functional.mish, (float('nan'),), order=1,
125
+ # check_grads(brainstate.nn.mish, (float('nan'),), order=1,
106
126
  # )
107
127
 
108
- @parameterized.parameters([float] + jtu.dtypes.floating)
128
+ @parameterized.parameters([float, jnp.float32, jnp.float64])
109
129
  def testMishZero(self, dtype):
110
- self.assertEqual(dtype(0), brainstate.functional.mish(dtype(0)))
130
+ self.assertEqual(dtype(0), brainstate.nn.mish(dtype(0)))
111
131
 
112
132
  def testReluGrad(self):
113
133
  rtol = None
114
- check_grads(brainstate.functional.relu, (1.,), order=3, rtol=rtol)
115
- check_grads(brainstate.functional.relu, (-1.,), order=3, rtol=rtol)
116
- jaxpr = jax.make_jaxpr(jax.grad(brainstate.functional.relu))(0.)
134
+ check_grads(brainstate.nn.relu, (1.,), order=3, rtol=rtol)
135
+ check_grads(brainstate.nn.relu, (-1.,), order=3, rtol=rtol)
136
+ jaxpr = jax.make_jaxpr(jax.grad(brainstate.nn.relu))(0.)
117
137
  self.assertGreaterEqual(len(jaxpr.jaxpr.eqns), 2)
118
138
 
119
139
  def testRelu6Grad(self):
120
140
  rtol = None
121
- check_grads(brainstate.functional.relu6, (1.,), order=3, rtol=rtol)
122
- check_grads(brainstate.functional.relu6, (-1.,), order=3, rtol=rtol)
123
- self.assertAllClose(jax.grad(brainstate.functional.relu6)(0.), 0., check_dtypes=False)
124
- self.assertAllClose(jax.grad(brainstate.functional.relu6)(6.), 0., check_dtypes=False)
141
+ check_grads(brainstate.nn.relu6, (1.,), order=3, rtol=rtol)
142
+ check_grads(brainstate.nn.relu6, (-1.,), order=3, rtol=rtol)
143
+ self.assertAllClose(jax.grad(brainstate.nn.relu6)(0.), 0., check_dtypes=False)
144
+ self.assertAllClose(jax.grad(brainstate.nn.relu6)(6.), 0., check_dtypes=False)
125
145
 
126
146
  def testSoftplusValue(self):
127
- val = brainstate.functional.softplus(89.)
147
+ val = brainstate.nn.softplus(89.)
128
148
  self.assertAllClose(val, 89., check_dtypes=False)
129
149
 
130
150
  def testSparseplusValue(self):
131
- val = brainstate.functional.sparse_plus(89.)
151
+ val = brainstate.nn.sparse_plus(89.)
132
152
  self.assertAllClose(val, 89., check_dtypes=False)
133
153
 
134
154
  def testSparsesigmoidValue(self):
135
- self.assertAllClose(brainstate.functional.sparse_sigmoid(-2.), 0., check_dtypes=False)
136
- self.assertAllClose(brainstate.functional.sparse_sigmoid(2.), 1., check_dtypes=False)
137
- self.assertAllClose(brainstate.functional.sparse_sigmoid(0.), .5, check_dtypes=False)
155
+ self.assertAllClose(brainstate.nn.sparse_sigmoid(-2.), 0., check_dtypes=False)
156
+ self.assertAllClose(brainstate.nn.sparse_sigmoid(2.), 1., check_dtypes=False)
157
+ self.assertAllClose(brainstate.nn.sparse_sigmoid(0.), .5, check_dtypes=False)
138
158
 
139
159
  # def testSquareplusValue(self):
140
- # val = brainstate.functional.squareplus(1e3)
160
+ # val = brainstate.nn.squareplus(1e3)
141
161
  # self.assertAllClose(val, 1e3, check_dtypes=False, atol=1e-3)
142
162
 
143
163
  def testMishValue(self):
144
- val = brainstate.functional.mish(1e3)
164
+ val = brainstate.nn.mish(1e3)
145
165
  self.assertAllClose(val, 1e3, check_dtypes=False, atol=1e-3)
146
166
 
147
167
  def testEluValue(self):
148
- val = brainstate.functional.elu(1e4)
168
+ val = brainstate.nn.elu(1e4)
149
169
  self.assertAllClose(val, 1e4, check_dtypes=False)
150
170
 
151
171
  def testGluValue(self):
152
- val = brainstate.functional.glu(jnp.array([1.0, 0.0]), axis=0)
172
+ val = brainstate.nn.glu(jnp.array([1.0, 0.0]), axis=0)
153
173
  self.assertAllClose(val, jnp.array([0.5]))
154
174
 
155
175
  @parameterized.parameters(False, True)
156
176
  def testGeluIntType(self, approximate):
157
- val_float = brainstate.functional.gelu(jnp.array(-1.0), approximate=approximate)
158
- val_int = brainstate.functional.gelu(jnp.array(-1), approximate=approximate)
177
+ val_float = brainstate.nn.gelu(jnp.array(-1.0), approximate=approximate)
178
+ val_int = brainstate.nn.gelu(jnp.array(-1), approximate=approximate)
159
179
  self.assertAllClose(val_float, val_int)
160
180
 
161
181
  @parameterized.parameters(False, True)
@@ -163,22 +183,21 @@ class NNFunctionsTest(jtu.JaxTestCase):
163
183
  def gelu_reference(x):
164
184
  return x * scipy.stats.norm.cdf(x)
165
185
 
166
- rng = jtu.rand_default(self.rng())
167
- args_maker = lambda: [rng((4, 5, 6), jnp.float32)]
168
- self._CheckAgainstNumpy(
169
- gelu_reference, partial(brainstate.functional.gelu, approximate=approximate), args_maker,
170
- check_dtypes=False, tol=1e-3 if approximate else None)
186
+ x = jax.random.normal(self.rng_key, (4, 5, 6), dtype=jnp.float32)
187
+ expected = gelu_reference(x)
188
+ actual = brainstate.nn.gelu(x, approximate=approximate)
189
+ np.testing.assert_allclose(actual, expected, rtol=1e-2 if approximate else 1e-5, atol=1e-3 if approximate else 1e-5)
171
190
 
172
191
  @parameterized.parameters(*itertools.product(
173
192
  (jnp.float32, jnp.bfloat16, jnp.float16),
174
- (partial(brainstate.functional.gelu, approximate=False),
175
- partial(brainstate.functional.gelu, approximate=True),
176
- brainstate.functional.relu,
177
- brainstate.functional.softplus,
178
- brainstate.functional.sparse_plus,
179
- brainstate.functional.sigmoid,
180
- # brainstate.functional.squareplus,
181
- brainstate.functional.mish)))
193
+ (partial(brainstate.nn.gelu, approximate=False),
194
+ partial(brainstate.nn.gelu, approximate=True),
195
+ brainstate.nn.relu,
196
+ brainstate.nn.softplus,
197
+ brainstate.nn.sparse_plus,
198
+ brainstate.nn.sigmoid,
199
+ # brainstate.nn.squareplus,
200
+ brainstate.nn.mish)))
182
201
  def testDtypeMatchesInput(self, dtype, fn):
183
202
  x = jnp.zeros((), dtype=dtype)
184
203
  out = fn(x)
@@ -187,26 +206,26 @@ class NNFunctionsTest(jtu.JaxTestCase):
187
206
  def testEluMemory(self):
188
207
  # see https://github.com/google/jax/pull/1640
189
208
  with jax.enable_checks(False): # With checks we materialize the array
190
- jax.make_jaxpr(lambda: brainstate.functional.elu(jnp.ones((10 ** 12,)))) # don't oom
209
+ jax.make_jaxpr(lambda: brainstate.nn.elu(jnp.ones((10 ** 12,)))) # don't oom
191
210
 
192
211
  def testHardTanhMemory(self):
193
212
  # see https://github.com/google/jax/pull/1640
194
213
  with jax.enable_checks(False): # With checks we materialize the array
195
- jax.make_jaxpr(lambda: brainstate.functional.hard_tanh(jnp.ones((10 ** 12,)))) # don't oom
214
+ jax.make_jaxpr(lambda: brainstate.nn.hard_tanh(jnp.ones((10 ** 12,)))) # don't oom
196
215
 
197
- @parameterized.parameters([brainstate.functional.softmax, brainstate.functional.log_softmax])
216
+ @parameterized.parameters([brainstate.nn.softmax, brainstate.nn.log_softmax])
198
217
  def testSoftmaxEmptyArray(self, fn):
199
218
  x = jnp.array([], dtype=float)
200
219
  self.assertArraysEqual(fn(x), x)
201
220
 
202
- @parameterized.parameters([brainstate.functional.softmax, brainstate.functional.log_softmax])
221
+ @parameterized.parameters([brainstate.nn.softmax, brainstate.nn.log_softmax])
203
222
  def testSoftmaxEmptyMask(self, fn):
204
223
  x = jnp.array([5.5, 1.3, -4.2, 0.9])
205
224
  m = jnp.zeros_like(x, dtype=bool)
206
- expected = jnp.full_like(x, 0.0 if fn is brainstate.functional.softmax else -jnp.inf)
225
+ expected = jnp.full_like(x, 0.0 if fn is brainstate.nn.softmax else -jnp.inf)
207
226
  self.assertArraysEqual(fn(x, where=m), expected)
208
227
 
209
- @parameterized.parameters([brainstate.functional.softmax, brainstate.functional.log_softmax])
228
+ @parameterized.parameters([brainstate.nn.softmax, brainstate.nn.log_softmax])
210
229
  def testSoftmaxWhereMask(self, fn):
211
230
  x = jnp.array([5.5, 1.3, -4.2, 0.9])
212
231
  m = jnp.array([True, False, True, True])
@@ -214,10 +233,10 @@ class NNFunctionsTest(jtu.JaxTestCase):
214
233
  out = fn(x, where=m)
215
234
  self.assertAllClose(out[m], fn(x[m]))
216
235
 
217
- probs = out if fn is brainstate.functional.softmax else jnp.exp(out)
218
- self.assertAllClose(probs.sum(), 1.0)
236
+ probs = out if fn is brainstate.nn.softmax else jnp.exp(out)
237
+ self.assertAllClose(probs.sum(), 1.0, check_dtypes=False)
219
238
 
220
- @parameterized.parameters([brainstate.functional.softmax, brainstate.functional.log_softmax])
239
+ @parameterized.parameters([brainstate.nn.softmax, brainstate.nn.log_softmax])
221
240
  def testSoftmaxWhereGrad(self, fn):
222
241
  # regression test for https://github.com/google/jax/issues/19490
223
242
  x = jnp.array([36., 10000.])
@@ -229,46 +248,46 @@ class NNFunctionsTest(jtu.JaxTestCase):
229
248
 
230
249
  def testSoftmaxGrad(self):
231
250
  x = jnp.array([5.5, 1.3, -4.2, 0.9])
232
- jtu.check_grads(brainstate.functional.softmax, (x,), order=2, atol=5e-3)
251
+ check_grads(brainstate.nn.softmax, (x,), order=2, atol=5e-3)
233
252
 
234
253
  def testStandardizeWhereMask(self):
235
254
  x = jnp.array([5.5, 1.3, -4.2, 0.9])
236
255
  m = jnp.array([True, False, True, True])
237
256
  x_filtered = jnp.take(x, jnp.array([0, 2, 3]))
238
257
 
239
- out_masked = jnp.take(brainstate.functional.standardize(x, where=m), jnp.array([0, 2, 3]))
240
- out_filtered = brainstate.functional.standardize(x_filtered)
258
+ out_masked = jnp.take(brainstate.nn.standardize(x, where=m), jnp.array([0, 2, 3]))
259
+ out_filtered = brainstate.nn.standardize(x_filtered)
241
260
 
242
- self.assertAllClose(out_masked, out_filtered)
261
+ self.assertAllClose(out_masked, out_filtered, rtol=1e-6, atol=1e-6)
243
262
 
244
263
  def testOneHot(self):
245
- actual = brainstate.functional.one_hot(jnp.array([0, 1, 2]), 3)
264
+ actual = brainstate.nn.one_hot(jnp.array([0, 1, 2]), 3)
246
265
  expected = jnp.array([[1., 0., 0.],
247
266
  [0., 1., 0.],
248
267
  [0., 0., 1.]])
249
268
  self.assertAllClose(actual, expected, check_dtypes=False)
250
269
 
251
- actual = brainstate.functional.one_hot(jnp.array([1, 2, 0]), 3)
270
+ actual = brainstate.nn.one_hot(jnp.array([1, 2, 0]), 3)
252
271
  expected = jnp.array([[0., 1., 0.],
253
272
  [0., 0., 1.],
254
273
  [1., 0., 0.]])
255
274
  self.assertAllClose(actual, expected, check_dtypes=False)
256
275
 
257
276
  def testOneHotOutOfBound(self):
258
- actual = brainstate.functional.one_hot(jnp.array([-1, 3]), 3)
277
+ actual = brainstate.nn.one_hot(jnp.array([-1, 3]), 3)
259
278
  expected = jnp.array([[0., 0., 0.],
260
279
  [0., 0., 0.]])
261
280
  self.assertAllClose(actual, expected, check_dtypes=False)
262
281
 
263
282
  def testOneHotNonArrayInput(self):
264
- actual = brainstate.functional.one_hot([0, 1, 2], 3)
283
+ actual = brainstate.nn.one_hot([0, 1, 2], 3)
265
284
  expected = jnp.array([[1., 0., 0.],
266
285
  [0., 1., 0.],
267
286
  [0., 0., 1.]])
268
287
  self.assertAllClose(actual, expected, check_dtypes=False)
269
288
 
270
289
  def testOneHotCustomDtype(self):
271
- actual = brainstate.functional.one_hot(jnp.array([0, 1, 2]), 3, dtype=jnp.bool_)
290
+ actual = brainstate.nn.one_hot(jnp.array([0, 1, 2]), 3, dtype=jnp.bool_)
272
291
  expected = jnp.array([[True, False, False],
273
292
  [False, True, False],
274
293
  [False, False, True]])
@@ -279,14 +298,14 @@ class NNFunctionsTest(jtu.JaxTestCase):
279
298
  [0., 0., 1.],
280
299
  [1., 0., 0.]]).T
281
300
 
282
- actual = brainstate.functional.one_hot(jnp.array([1, 2, 0]), 3, axis=0)
301
+ actual = brainstate.nn.one_hot(jnp.array([1, 2, 0]), 3, axis=0)
283
302
  self.assertAllClose(actual, expected, check_dtypes=False)
284
303
 
285
- actual = brainstate.functional.one_hot(jnp.array([1, 2, 0]), 3, axis=-2)
304
+ actual = brainstate.nn.one_hot(jnp.array([1, 2, 0]), 3, axis=-2)
286
305
  self.assertAllClose(actual, expected, check_dtypes=False)
287
306
 
288
307
  def testTanhExists(self):
289
- print(brainstate.functional.tanh) # doesn't crash
308
+ print(brainstate.nn.tanh) # doesn't crash
290
309
 
291
310
  def testCustomJVPLeak(self):
292
311
  # https://github.com/google/jax/issues/8171
@@ -295,7 +314,7 @@ class NNFunctionsTest(jtu.JaxTestCase):
295
314
  a = jnp.array(1.)
296
315
 
297
316
  def f(hx, _):
298
- hx = brainstate.functional.sigmoid(hx + a)
317
+ hx = brainstate.nn.sigmoid(hx + a)
299
318
  return hx, None
300
319
 
301
320
  hx = jnp.array(0.)
@@ -306,7 +325,7 @@ class NNFunctionsTest(jtu.JaxTestCase):
306
325
 
307
326
  def testCustomJVPLeak2(self):
308
327
  # https://github.com/google/jax/issues/8171
309
- # The above test uses jax.brainstate.functional.sigmoid, as in the original #8171, but that
328
+ # The above test uses jax.brainstate.nn.sigmoid, as in the original #8171, but that
310
329
  # function no longer actually has a custom_jvp! So we inline the old def.
311
330
 
312
331
  @jax.custom_jvp
@@ -329,3 +348,7 @@ class NNFunctionsTest(jtu.JaxTestCase):
329
348
 
330
349
  with jax.checking_leaks():
331
350
  fwd() # doesn't crash
351
+
352
+
353
+ if __name__ == '__main__':
354
+ absltest.main()