brainstate 0.1.0.post20250212__py2.py3-none-any.whl → 0.1.0.post20250216__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (96)
  1. brainstate/_state.py +853 -90
  2. brainstate/_state_test.py +1 -3
  3. brainstate/augment/__init__.py +2 -2
  4. brainstate/augment/_autograd.py +257 -115
  5. brainstate/augment/_autograd_test.py +2 -3
  6. brainstate/augment/_eval_shape.py +3 -4
  7. brainstate/augment/_mapping.py +582 -62
  8. brainstate/augment/_mapping_test.py +114 -30
  9. brainstate/augment/_random.py +61 -7
  10. brainstate/compile/_ad_checkpoint.py +2 -3
  11. brainstate/compile/_conditions.py +4 -5
  12. brainstate/compile/_conditions_test.py +1 -2
  13. brainstate/compile/_error_if.py +1 -2
  14. brainstate/compile/_error_if_test.py +1 -2
  15. brainstate/compile/_jit.py +23 -16
  16. brainstate/compile/_jit_test.py +1 -2
  17. brainstate/compile/_loop_collect_return.py +18 -10
  18. brainstate/compile/_loop_collect_return_test.py +1 -1
  19. brainstate/compile/_loop_no_collection.py +5 -5
  20. brainstate/compile/_make_jaxpr.py +23 -21
  21. brainstate/compile/_make_jaxpr_test.py +1 -2
  22. brainstate/compile/_progress_bar.py +1 -2
  23. brainstate/compile/_unvmap.py +1 -0
  24. brainstate/compile/_util.py +4 -2
  25. brainstate/environ.py +4 -4
  26. brainstate/environ_test.py +1 -2
  27. brainstate/functional/_activations.py +1 -2
  28. brainstate/functional/_activations_test.py +1 -1
  29. brainstate/functional/_normalization.py +1 -2
  30. brainstate/functional/_others.py +1 -2
  31. brainstate/functional/_spikes.py +136 -20
  32. brainstate/graph/_graph_node.py +2 -43
  33. brainstate/graph/_graph_operation.py +4 -20
  34. brainstate/graph/_graph_operation_test.py +3 -4
  35. brainstate/init/_base.py +1 -2
  36. brainstate/init/_generic.py +1 -2
  37. brainstate/nn/__init__.py +4 -0
  38. brainstate/nn/_collective_ops.py +351 -48
  39. brainstate/nn/_collective_ops_test.py +36 -0
  40. brainstate/nn/_common.py +194 -0
  41. brainstate/nn/_dyn_impl/_dynamics_neuron.py +1 -2
  42. brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +1 -2
  43. brainstate/nn/_dyn_impl/_dynamics_synapse.py +1 -2
  44. brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +1 -2
  45. brainstate/nn/_dyn_impl/_inputs.py +1 -2
  46. brainstate/nn/_dyn_impl/_rate_rnns.py +1 -2
  47. brainstate/nn/_dyn_impl/_rate_rnns_test.py +1 -2
  48. brainstate/nn/_dyn_impl/_readout.py +2 -3
  49. brainstate/nn/_dyn_impl/_readout_test.py +1 -2
  50. brainstate/nn/_dynamics/_dynamics_base.py +2 -3
  51. brainstate/nn/_dynamics/_dynamics_base_test.py +1 -2
  52. brainstate/nn/_dynamics/_state_delay.py +3 -3
  53. brainstate/nn/_dynamics/_synouts_test.py +1 -2
  54. brainstate/nn/_elementwise/_dropout.py +6 -7
  55. brainstate/nn/_elementwise/_dropout_test.py +1 -2
  56. brainstate/nn/_elementwise/_elementwise.py +1 -2
  57. brainstate/nn/_exp_euler.py +1 -2
  58. brainstate/nn/_exp_euler_test.py +1 -2
  59. brainstate/nn/_interaction/_conv.py +1 -2
  60. brainstate/nn/_interaction/_conv_test.py +1 -0
  61. brainstate/nn/_interaction/_linear.py +1 -2
  62. brainstate/nn/_interaction/_linear_test.py +1 -2
  63. brainstate/nn/_interaction/_normalizations.py +1 -2
  64. brainstate/nn/_interaction/_poolings.py +3 -4
  65. brainstate/nn/_module.py +63 -19
  66. brainstate/nn/_module_test.py +1 -2
  67. brainstate/nn/metrics.py +3 -4
  68. brainstate/optim/_lr_scheduler.py +1 -2
  69. brainstate/optim/_lr_scheduler_test.py +2 -3
  70. brainstate/optim/_optax_optimizer_test.py +1 -2
  71. brainstate/optim/_sgd_optimizer.py +2 -3
  72. brainstate/random/_rand_funs.py +1 -2
  73. brainstate/random/_rand_funs_test.py +2 -3
  74. brainstate/random/_rand_seed.py +2 -3
  75. brainstate/random/_rand_seed_test.py +1 -2
  76. brainstate/random/_rand_state.py +3 -4
  77. brainstate/surrogate.py +5 -2
  78. brainstate/transform.py +0 -3
  79. brainstate/typing.py +28 -25
  80. brainstate/util/__init__.py +9 -7
  81. brainstate/util/_caller.py +1 -2
  82. brainstate/util/_error.py +27 -0
  83. brainstate/util/_others.py +60 -15
  84. brainstate/util/{_dict.py → _pretty_pytree.py} +2 -2
  85. brainstate/util/{_dict_test.py → _pretty_pytree_test.py} +1 -2
  86. brainstate/util/_pretty_repr.py +1 -2
  87. brainstate/util/_pretty_table.py +2900 -0
  88. brainstate/util/_struct.py +11 -11
  89. brainstate/util/filter.py +472 -0
  90. {brainstate-0.1.0.post20250212.dist-info → brainstate-0.1.0.post20250216.dist-info}/METADATA +2 -2
  91. brainstate-0.1.0.post20250216.dist-info/RECORD +127 -0
  92. brainstate/util/_filter.py +0 -178
  93. brainstate-0.1.0.post20250212.dist-info/RECORD +0 -124
  94. {brainstate-0.1.0.post20250212.dist-info → brainstate-0.1.0.post20250216.dist-info}/LICENSE +0 -0
  95. {brainstate-0.1.0.post20250212.dist-info → brainstate-0.1.0.post20250216.dist-info}/WHEEL +0 -0
  96. {brainstate-0.1.0.post20250212.dist-info → brainstate-0.1.0.post20250216.dist-info}/top_level.txt +0 -0
@@ -17,11 +17,10 @@
17
17
 
18
18
  from __future__ import annotations
19
19
 
20
- import unittest
21
-
22
20
  import brainunit as u
23
21
  import jax
24
22
  import jax.numpy as jnp
23
+ import unittest
25
24
 
26
25
  import brainstate as bst
27
26
  from brainstate.nn import IF, LIF, ALIF
@@ -17,9 +17,8 @@
17
17
 
18
18
  from __future__ import annotations
19
19
 
20
- from typing import Optional
21
-
22
20
  import brainunit as u
21
+ from typing import Optional
23
22
 
24
23
  from brainstate import init, environ
25
24
  from brainstate._state import ShortTermState, HiddenState
@@ -15,11 +15,10 @@
15
15
 
16
16
  from __future__ import annotations
17
17
 
18
- import unittest
19
-
20
18
  import brainunit as u
21
19
  import jax.numpy as jnp
22
20
  import pytest
21
+ import unittest
23
22
 
24
23
  import brainstate as bst
25
24
  from brainstate.nn import Expon, STP, STD
@@ -14,11 +14,10 @@
14
14
  # ==============================================================================
15
15
  from __future__ import annotations
16
16
 
17
- from typing import Union, Optional, Sequence, Callable
18
-
19
17
  import brainunit as u
20
18
  import jax
21
19
  import numpy as np
20
+ from typing import Union, Optional, Sequence, Callable
22
21
 
23
22
  from brainstate import environ, init, random
24
23
  from brainstate._state import ShortTermState
@@ -17,9 +17,8 @@
17
17
 
18
18
  from __future__ import annotations
19
19
 
20
- from typing import Callable, Union
21
-
22
20
  import jax.numpy as jnp
21
+ from typing import Callable, Union
23
22
 
24
23
  from brainstate import random, init, functional
25
24
  from brainstate._state import HiddenState, ParamState
@@ -15,9 +15,8 @@
15
15
 
16
16
  from __future__ import annotations
17
17
 
18
- import unittest
19
-
20
18
  import jax.numpy as jnp
19
+ import unittest
21
20
 
22
21
  import brainstate as bst
23
22
 
@@ -17,11 +17,10 @@
17
17
 
18
18
  from __future__ import annotations
19
19
 
20
- import numbers
21
- from typing import Callable
22
-
23
20
  import brainunit as u
24
21
  import jax
22
+ import numbers
23
+ from typing import Callable
25
24
 
26
25
  from brainstate import environ, init, surrogate
27
26
  from brainstate._state import HiddenState, ParamState
@@ -15,9 +15,8 @@
15
15
 
16
16
  from __future__ import annotations
17
17
 
18
- import unittest
19
-
20
18
  import jax.numpy as jnp
19
+ import unittest
21
20
 
22
21
  import brainstate as bst
23
22
 
@@ -34,10 +34,9 @@ For handling the delays:
34
34
  """
35
35
  from __future__ import annotations
36
36
 
37
- from typing import Any, Dict, Callable, Hashable, Optional, Union, TypeVar, TYPE_CHECKING
38
-
39
37
  import brainunit as u
40
38
  import numpy as np
39
+ from typing import Any, Dict, Callable, Hashable, Optional, Union, TypeVar, TYPE_CHECKING
41
40
 
42
41
  from brainstate import environ
43
42
  from brainstate._state import State
@@ -420,7 +419,7 @@ class Dynamics(Module):
420
419
  else:
421
420
  raise TypeError(f'The input {dyn} should be an instance of {Dynamics} or a delayed initializer.')
422
421
 
423
- def __leaf_fn__(self, name, value):
422
+ def __pretty_repr_item__(self, name, value):
424
423
  if name in ['_in_size', '_out_size', '_name', '_mode',
425
424
  '_before_updates', '_after_updates', '_current_inputs', '_delta_inputs']:
426
425
  return (name, value) if value is None else (name[1:], value) # skip the first `_`
@@ -17,9 +17,8 @@
17
17
 
18
18
  from __future__ import annotations
19
19
 
20
- import unittest
21
-
22
20
  import numpy as np
21
+ import unittest
23
22
 
24
23
  import brainstate as bst
25
24
 
@@ -16,14 +16,14 @@
16
16
  from __future__ import annotations
17
17
 
18
18
  import math
19
- import numbers
20
- from functools import partial
21
- from typing import Optional, Dict, Callable, Union, Sequence
22
19
 
23
20
  import brainunit as u
24
21
  import jax
25
22
  import jax.numpy as jnp
23
+ import numbers
26
24
  import numpy as np
25
+ from functools import partial
26
+ from typing import Optional, Dict, Callable, Union, Sequence
27
27
 
28
28
  from brainstate import environ
29
29
  from brainstate._state import ShortTermState, State
@@ -15,11 +15,10 @@
15
15
 
16
16
  from __future__ import annotations
17
17
 
18
- import unittest
19
-
20
18
  import brainunit as u
21
19
  import jax.numpy as jnp
22
20
  import numpy as np
21
+ import unittest
23
22
 
24
23
  import brainstate as bst
25
24
 
@@ -16,11 +16,10 @@
16
16
 
17
17
  from __future__ import annotations
18
18
 
19
- from functools import partial
20
- from typing import Optional, Sequence
21
-
22
19
  import brainunit as u
23
20
  import jax.numpy as jnp
21
+ from functools import partial
22
+ from typing import Optional, Sequence
24
23
 
25
24
  from brainstate import random, environ, init
26
25
  from brainstate._state import ShortTermState
@@ -72,11 +71,11 @@ class Dropout(ElementWiseBlock):
72
71
  for dim in self.broadcast_dims:
73
72
  broadcast_shape[dim] = 1
74
73
  keep_mask = random.bernoulli(self.prob, broadcast_shape)
75
- keep_mask = jnp.broadcast_to(keep_mask, x.shape)
76
- return jnp.where(
74
+ keep_mask = u.math.broadcast_to(keep_mask, x.shape)
75
+ return u.math.where(
77
76
  keep_mask,
78
- jnp.asarray(x / self.prob, dtype=dtype),
79
- jnp.asarray(0., dtype=dtype)
77
+ u.math.asarray(x / self.prob, dtype=dtype),
78
+ u.math.asarray(0., dtype=dtype)
80
79
  )
81
80
  else:
82
81
  return x
@@ -14,9 +14,8 @@
14
14
  # ==============================================================================
15
15
 
16
16
 
17
- import unittest
18
-
19
17
  import numpy as np
18
+ import unittest
20
19
 
21
20
  import brainstate as bst
22
21
 
@@ -17,11 +17,10 @@
17
17
 
18
18
  from __future__ import annotations
19
19
 
20
- from typing import Optional
21
-
22
20
  import brainunit as u
23
21
  import jax.numpy as jnp
24
22
  import jax.typing
23
+ from typing import Optional
25
24
 
26
25
  from brainstate import random, functional as F
27
26
  from brainstate._state import ParamState
@@ -16,10 +16,9 @@
16
16
 
17
17
  from __future__ import annotations
18
18
 
19
- from typing import Callable
20
-
21
19
  import brainunit as u
22
20
  import jax.numpy as jnp
21
+ from typing import Callable
23
22
 
24
23
  from brainstate import environ, random
25
24
  from brainstate.augment import vector_grad
@@ -15,9 +15,8 @@
15
15
 
16
16
  from __future__ import annotations
17
17
 
18
- import unittest
19
-
20
18
  import brainunit as u
19
+ import unittest
21
20
 
22
21
  import brainstate as bst
23
22
 
@@ -18,10 +18,9 @@
18
18
  from __future__ import annotations
19
19
 
20
20
  import collections.abc
21
- from typing import Callable, Tuple, Union, Sequence, Optional, TypeVar
22
-
23
21
  import jax
24
22
  import jax.numpy as jnp
23
+ from typing import Callable, Tuple, Union, Sequence, Optional, TypeVar
25
24
 
26
25
  from brainstate import init, functional
27
26
  from brainstate._state import ParamState
@@ -235,5 +235,6 @@ class TestConvTranspose3d(parameterized.TestCase):
235
235
  y = conv_transpose_module(x)
236
236
  print(y.shape)
237
237
 
238
+
238
239
  if __name__ == '__main__':
239
240
  absltest.main()
@@ -17,10 +17,9 @@
17
17
 
18
18
  from __future__ import annotations
19
19
 
20
- from typing import Callable, Union, Optional
21
-
22
20
  import brainunit as u
23
21
  import jax.numpy as jnp
22
+ from typing import Callable, Union, Optional
24
23
 
25
24
  from brainstate import init, functional
26
25
  from brainstate._state import ParamState
@@ -16,9 +16,8 @@
16
16
 
17
17
  from __future__ import annotations
18
18
 
19
- import unittest
20
-
21
19
  import brainunit as u
20
+ import unittest
22
21
  from absl.testing import parameterized
23
22
 
24
23
  import brainstate as bst
@@ -17,10 +17,9 @@
17
17
 
18
18
  from __future__ import annotations
19
19
 
20
- from typing import Callable, Union, Sequence, Optional, Any, Dict
21
-
22
20
  import jax
23
21
  import jax.numpy as jnp
22
+ from typing import Callable, Union, Sequence, Optional, Any
24
23
 
25
24
  from brainstate import environ, init
26
25
  from brainstate._state import ParamState, BatchState
@@ -17,14 +17,13 @@
17
17
 
18
18
  from __future__ import annotations
19
19
 
20
- import functools
21
- from typing import Sequence, Optional
22
- from typing import Union, Tuple, Callable, List
23
-
24
20
  import brainunit as u
21
+ import functools
25
22
  import jax
26
23
  import jax.numpy as jnp
27
24
  import numpy as np
25
+ from typing import Sequence, Optional
26
+ from typing import Union, Tuple, Callable, List
28
27
 
29
28
  from brainstate import environ
30
29
  from brainstate.nn._module import Module
brainstate/nn/_module.py CHANGED
@@ -27,16 +27,15 @@ The basic classes include:
27
27
  """
28
28
  from __future__ import annotations
29
29
 
30
- import warnings
31
- from typing import Sequence, Optional, Tuple, Union, TYPE_CHECKING
32
-
33
30
  import numpy as np
31
+ import warnings
32
+ from typing import Sequence, Optional, Tuple, Union, TYPE_CHECKING, Callable
34
33
 
35
34
  from brainstate._state import State
36
35
  from brainstate.graph import Node, states, nodes, flatten
37
36
  from brainstate.mixin import ParamDescriber, ParamDesc
38
37
  from brainstate.typing import PathParts
39
- from brainstate.util import FlattedDict, NestedDict
38
+ from brainstate.util import FlattedDict, NestedDict, BrainStateError
40
39
 
41
40
  # maximum integer
42
41
  max_int = np.iinfo(np.int32).max
@@ -226,7 +225,7 @@ class Module(Node, ParamDesc):
226
225
  """
227
226
  pass
228
227
 
229
- def __leaf_fn__(self, name, value):
228
+ def __pretty_repr_item__(self, name, value):
230
229
  if name in ['_in_size', '_out_size', '_name']:
231
230
  return (name, value) if value is None else (name[1:], value) # skip the first `_`
232
231
  return name, value
@@ -288,7 +287,7 @@ class Sequential(Module):
288
287
  in_size = first.out_size
289
288
  self.layers.append(first)
290
289
  for module in layers:
291
- module, in_size = _format_module(module, in_size)
290
+ module, in_size = self._format_module(module, in_size)
292
291
  self.layers.append(module)
293
292
 
294
293
  # the input and output shape
@@ -301,7 +300,14 @@ class Sequential(Module):
301
300
  """Update function of a sequential model.
302
301
  """
303
302
  for m in self.layers:
304
- x = m(x)
303
+ try:
304
+ x = m(x)
305
+ except Exception as e:
306
+ raise BrainStateError(
307
+ f'The module \n'
308
+ f'{m}\n'
309
+ f'failed to update with input {x}\n'
310
+ ) from e
305
311
  return x
306
312
 
307
313
  def __getitem__(self, key: Union[int, slice]):
@@ -314,16 +320,54 @@ class Sequential(Module):
314
320
  else:
315
321
  raise KeyError(f'Unknown type of key: {type(key)}')
316
322
 
323
+ def append(self, layer: Callable):
324
+ """
325
+ Append a layer to the sequential model.
326
+
327
+ This method adds a new layer to the end of the sequential model. The layer can be
328
+ either a Module instance, an ElementWiseBlock instance, or a callable function. If the
329
+ layer is a callable function, it will be wrapped in an ElementWiseBlock instance.
330
+
331
+ Parameters:
332
+ ----------
333
+ layer : Callable
334
+ The layer to be appended to the sequential model. It can be a Module instance,
335
+ an ElementWiseBlock instance, or a callable function.
336
+
337
+ Raises:
338
+ -------
339
+ ValueError
340
+ If the sequential model is empty and the first layer is a callable function.
317
341
 
318
- def _format_module(module, in_size):
319
- if isinstance(module, ParamDescriber):
320
- module = module(in_size=in_size)
321
- assert isinstance(module, Module), 'The module should be an instance of Module.'
322
- out_size = module.out_size
323
- elif isinstance(module, ElementWiseBlock):
324
- out_size = in_size
325
- elif isinstance(module, Module):
326
- out_size = module.out_size
327
- else:
328
- raise TypeError(f"Unsupported type {type(module)}. ")
329
- return module, out_size
342
+ Returns:
343
+ --------
344
+ None
345
+ The method does not return any value. It modifies the sequential model by adding
346
+ the new layer to the end.
347
+ """
348
+ if len(self.layers) == 0:
349
+ raise ValueError('The first layer should be a module, not a function.')
350
+ module, in_size = self._format_module(layer, self.out_size)
351
+ self.layers.append(module)
352
+ self.out_size = in_size
353
+
354
+ def _format_module(self, module, in_size):
355
+ if isinstance(module, ParamDescriber):
356
+ if in_size is None:
357
+ raise ValueError(
358
+ 'The input size should be specified. '
359
+ f'Please set the in_size attribute of the previous module: \n'
360
+ f'{self.layers[-1]}'
361
+ )
362
+ module = module(in_size=in_size)
363
+ assert isinstance(module, Module), 'The module should be an instance of Module.'
364
+ out_size = module.out_size
365
+ elif isinstance(module, ElementWiseBlock):
366
+ out_size = in_size
367
+ elif isinstance(module, Module):
368
+ out_size = module.out_size
369
+ elif callable(module):
370
+ out_size = in_size
371
+ else:
372
+ raise TypeError(f"Unsupported type {type(module)}. ")
373
+ return module, out_size
@@ -15,10 +15,9 @@
15
15
 
16
16
  from __future__ import annotations
17
17
 
18
- import unittest
19
-
20
18
  import jax.numpy as jnp
21
19
  import jaxlib.xla_extension
20
+ import unittest
22
21
 
23
22
  import brainstate as bst
24
23
 
brainstate/nn/metrics.py CHANGED
@@ -16,13 +16,12 @@
16
16
 
17
17
  from __future__ import annotations
18
18
 
19
- import typing as tp
20
- from dataclasses import dataclass
21
- from functools import partial
22
-
23
19
  import jax
24
20
  import jax.numpy as jnp
25
21
  import numpy as np
22
+ import typing as tp
23
+ from dataclasses import dataclass
24
+ from functools import partial
26
25
 
27
26
  from brainstate._state import State
28
27
 
@@ -16,11 +16,10 @@
16
16
  # -*- coding: utf-8 -*-
17
17
  from __future__ import annotations
18
18
 
19
- from typing import Sequence, Union
20
-
21
19
  import jax
22
20
  import jax.numpy as jnp
23
21
  import numpy as np
22
+ from typing import Sequence, Union
24
23
 
25
24
  from brainstate import environ
26
25
  from brainstate._state import State, LongTermState
@@ -15,9 +15,8 @@
15
15
 
16
16
  from __future__ import annotations
17
17
 
18
- import unittest
19
-
20
18
  import jax.numpy as jnp
19
+ import unittest
21
20
 
22
21
  import brainstate as bst
23
22
 
@@ -37,7 +36,7 @@ class TestMultiStepLR(unittest.TestCase):
37
36
  self.assertTrue(jnp.allclose(r, 0.0001))
38
37
 
39
38
  def test2(self):
40
- lr = bst.transform.jit(bst.optim.MultiStepLR(0.1, [10, 20, 30], gamma=0.1))
39
+ lr = bst.compile.jit(bst.optim.MultiStepLR(0.1, [10, 20, 30], gamma=0.1))
41
40
  for i in range(40):
42
41
  r = lr(i)
43
42
  if i < 10:
@@ -15,10 +15,9 @@
15
15
 
16
16
  from __future__ import annotations
17
17
 
18
- import unittest
19
-
20
18
  import jax
21
19
  import optax
20
+ import unittest
22
21
 
23
22
  import brainstate as bst
24
23
 
@@ -16,12 +16,11 @@
16
16
  # -*- coding: utf-8 -*-
17
17
  from __future__ import annotations
18
18
 
19
- import functools
20
- from typing import Union, Dict, Optional, Tuple, Any, TypeVar
21
-
22
19
  import brainunit as u
20
+ import functools
23
21
  import jax
24
22
  import jax.numpy as jnp
23
+ from typing import Union, Dict, Optional, Tuple, Any, TypeVar
25
24
 
26
25
  from brainstate import environ
27
26
  from brainstate._state import State, LongTermState, StateDictManager
@@ -17,9 +17,8 @@
17
17
  # -*- coding: utf-8 -*-
18
18
  from __future__ import annotations
19
19
 
20
- from typing import Optional
21
-
22
20
  import numpy as np
21
+ from typing import Optional
23
22
 
24
23
  from brainstate.typing import DTypeLike, Size, SeedOrKey
25
24
  from ._rand_state import RandomState, DEFAULT
@@ -15,13 +15,12 @@
15
15
 
16
16
  from __future__ import annotations
17
17
 
18
- import platform
19
- import unittest
20
-
21
18
  import jax.numpy as jnp
22
19
  import jax.random as jr
23
20
  import numpy as np
21
+ import platform
24
22
  import pytest
23
+ import unittest
25
24
 
26
25
  import brainstate as bst
27
26
 
@@ -14,11 +14,10 @@
14
14
  # ==============================================================================
15
15
  from __future__ import annotations
16
16
 
17
- from contextlib import contextmanager
18
- from typing import Optional
19
-
20
17
  import jax
21
18
  import numpy as np
19
+ from contextlib import contextmanager
20
+ from typing import Optional
22
21
 
23
22
  from brainstate.typing import SeedOrKey
24
23
  from ._rand_state import RandomState, DEFAULT, use_prng_key
@@ -15,10 +15,9 @@
15
15
 
16
16
  from __future__ import annotations
17
17
 
18
- import unittest
19
-
20
18
  import jax.numpy as jnp
21
19
  import jax.random
20
+ import unittest
22
21
 
23
22
  import brainstate as bst
24
23
 
@@ -16,17 +16,16 @@
16
16
  # -*- coding: utf-8 -*-
17
17
  from __future__ import annotations
18
18
 
19
- from functools import partial
20
- from operator import index
21
- from typing import Optional
22
-
23
19
  import brainunit as u
24
20
  import jax
25
21
  import jax.numpy as jnp
26
22
  import jax.random as jr
27
23
  import numpy as np
24
+ from functools import partial
28
25
  from jax import jit, vmap
29
26
  from jax import lax, core, dtypes
27
+ from operator import index
28
+ from typing import Optional
30
29
 
31
30
  from brainstate import environ
32
31
  from brainstate._state import State
brainstate/surrogate.py CHANGED
@@ -21,7 +21,7 @@ import jax.numpy as jnp
21
21
  import jax.scipy as sci
22
22
  from jax.interpreters import batching, ad, mlir
23
23
 
24
- from brainstate.util._dict import PrettyObject
24
+ from brainstate.util._pretty_pytree import PrettyObject
25
25
 
26
26
  if jax.__version_info__ < (0, 4, 38):
27
27
  from jax.core import Primitive
@@ -79,7 +79,10 @@ def _heaviside_imp(x, dx):
79
79
 
80
80
 
81
81
  def _heaviside_batching(args, axes):
82
- return heaviside_p.bind(*args), tuple(axes)
82
+ x, dx = args
83
+ if axes[0] != axes[1]:
84
+ dx = batching.moveaxis(dx, axes[1], axes[0])
85
+ return heaviside_p.bind(x, dx), tuple([axes[0]])
83
86
 
84
87
 
85
88
  def _heaviside_jvp(primals, tangents):
brainstate/transform.py CHANGED
@@ -15,6 +15,3 @@
15
15
 
16
16
  # alias for compilation and augmentation functions
17
17
 
18
- from .compile import *
19
- from .augment import *
20
-