brainstate 0.1.1__py2.py3-none-any.whl → 0.1.3__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (73) hide show
  1. brainstate/__init__.py +1 -1
  2. brainstate/_compatible_import.py +12 -9
  3. brainstate/_state.py +1 -1
  4. brainstate/augment/_autograd_test.py +132 -133
  5. brainstate/augment/_eval_shape_test.py +7 -9
  6. brainstate/augment/_mapping_test.py +75 -76
  7. brainstate/compile/_ad_checkpoint_test.py +6 -8
  8. brainstate/compile/_conditions_test.py +35 -36
  9. brainstate/compile/_error_if_test.py +10 -13
  10. brainstate/compile/_loop_collect_return_test.py +7 -9
  11. brainstate/compile/_loop_no_collection_test.py +7 -8
  12. brainstate/compile/_make_jaxpr.py +29 -14
  13. brainstate/compile/_make_jaxpr_test.py +20 -20
  14. brainstate/functional/_activations_test.py +61 -61
  15. brainstate/graph/_graph_node_test.py +16 -18
  16. brainstate/graph/_graph_operation_test.py +154 -156
  17. brainstate/init/_random_inits_test.py +20 -21
  18. brainstate/init/_regular_inits_test.py +4 -5
  19. brainstate/mixin.py +1 -14
  20. brainstate/nn/__init__.py +81 -17
  21. brainstate/nn/_collective_ops_test.py +8 -8
  22. brainstate/nn/{_interaction/_conv.py → _conv.py} +1 -1
  23. brainstate/nn/{_interaction/_conv_test.py → _conv_test.py} +31 -33
  24. brainstate/nn/{_dynamics/_state_delay.py → _delay.py} +6 -2
  25. brainstate/nn/{_elementwise/_dropout.py → _dropout.py} +1 -1
  26. brainstate/nn/{_elementwise/_dropout_test.py → _dropout_test.py} +9 -9
  27. brainstate/nn/{_dynamics/_dynamics_base.py → _dynamics.py} +139 -18
  28. brainstate/nn/{_dynamics/_dynamics_base_test.py → _dynamics_test.py} +14 -15
  29. brainstate/nn/{_elementwise/_elementwise.py → _elementwise.py} +1 -1
  30. brainstate/nn/_elementwise_test.py +169 -0
  31. brainstate/nn/{_interaction/_embedding.py → _embedding.py} +1 -1
  32. brainstate/nn/_exp_euler_test.py +5 -6
  33. brainstate/nn/{_event/_fixedprob_mv.py → _fixedprob_mv.py} +1 -1
  34. brainstate/nn/{_event/_fixedprob_mv_test.py → _fixedprob_mv_test.py} +0 -1
  35. brainstate/nn/{_dyn_impl/_inputs.py → _inputs.py} +4 -5
  36. brainstate/nn/{_interaction/_linear.py → _linear.py} +2 -5
  37. brainstate/nn/{_event/_linear_mv.py → _linear_mv.py} +1 -1
  38. brainstate/nn/{_event/_linear_mv_test.py → _linear_mv_test.py} +0 -1
  39. brainstate/nn/{_interaction/_linear_test.py → _linear_test.py} +15 -17
  40. brainstate/nn/{_event/__init__.py → _ltp.py} +7 -5
  41. brainstate/nn/_module_test.py +34 -37
  42. brainstate/nn/{_dyn_impl/_dynamics_neuron.py → _neuron.py} +2 -2
  43. brainstate/nn/{_dyn_impl/_dynamics_neuron_test.py → _neuron_test.py} +18 -19
  44. brainstate/nn/{_interaction/_normalizations.py → _normalizations.py} +1 -1
  45. brainstate/nn/{_interaction/_normalizations_test.py → _normalizations_test.py} +10 -12
  46. brainstate/nn/{_interaction/_poolings.py → _poolings.py} +1 -1
  47. brainstate/nn/{_interaction/_poolings_test.py → _poolings_test.py} +20 -22
  48. brainstate/nn/{_dynamics/_projection_base.py → _projection.py} +35 -3
  49. brainstate/nn/{_dyn_impl/_rate_rnns.py → _rate_rnns.py} +2 -2
  50. brainstate/nn/{_dyn_impl/_rate_rnns_test.py → _rate_rnns_test.py} +6 -7
  51. brainstate/nn/{_dyn_impl/_readout.py → _readout.py} +3 -3
  52. brainstate/nn/{_dyn_impl/_readout_test.py → _readout_test.py} +9 -10
  53. brainstate/nn/_stp.py +236 -0
  54. brainstate/nn/{_dyn_impl/_dynamics_synapse.py → _synapse.py} +17 -206
  55. brainstate/nn/{_dyn_impl/_dynamics_synapse_test.py → _synapse_test.py} +9 -10
  56. brainstate/nn/_synaptic_projection.py +133 -0
  57. brainstate/nn/{_dynamics/_synouts.py → _synouts.py} +4 -1
  58. brainstate/nn/{_dynamics/_synouts_test.py → _synouts_test.py} +4 -5
  59. brainstate/optim/_lr_scheduler_test.py +3 -3
  60. brainstate/optim/_optax_optimizer_test.py +8 -9
  61. brainstate/random/_rand_funs_test.py +183 -184
  62. brainstate/random/_rand_seed_test.py +10 -12
  63. {brainstate-0.1.1.dist-info → brainstate-0.1.3.dist-info}/METADATA +1 -1
  64. brainstate-0.1.3.dist-info/RECORD +131 -0
  65. brainstate/nn/_dyn_impl/__init__.py +0 -42
  66. brainstate/nn/_dynamics/__init__.py +0 -37
  67. brainstate/nn/_elementwise/__init__.py +0 -22
  68. brainstate/nn/_elementwise/_elementwise_test.py +0 -171
  69. brainstate/nn/_interaction/__init__.py +0 -41
  70. brainstate-0.1.1.dist-info/RECORD +0 -133
  71. {brainstate-0.1.1.dist-info → brainstate-0.1.3.dist-info}/LICENSE +0 -0
  72. {brainstate-0.1.1.dist-info → brainstate-0.1.3.dist-info}/WHEEL +0 -0
  73. {brainstate-0.1.1.dist-info → brainstate-0.1.3.dist-info}/top_level.txt +0 -0
@@ -18,19 +18,19 @@ import unittest
18
18
 
19
19
  import numpy as np
20
20
 
21
- import brainstate as bst
21
+ import brainstate
22
22
 
23
23
 
24
24
  class TestDropout(unittest.TestCase):
25
25
 
26
26
  def test_dropout(self):
27
27
  # Create a Dropout layer with a dropout rate of 0.5
28
- dropout_layer = bst.nn.Dropout(0.5)
28
+ dropout_layer = brainstate.nn.Dropout(0.5)
29
29
 
30
30
  # Input data
31
31
  input_data = np.arange(20)
32
32
 
33
- with bst.environ.context(fit=True):
33
+ with brainstate.environ.context(fit=True):
34
34
  # Apply dropout
35
35
  output_data = dropout_layer(input_data)
36
36
 
@@ -47,10 +47,10 @@ class TestDropout(unittest.TestCase):
47
47
  np.testing.assert_almost_equal(non_zero_elements, expected_non_zero_elements)
48
48
 
49
49
  def test_DropoutFixed(self):
50
- dropout_layer = bst.nn.DropoutFixed(in_size=(2, 3), prob=0.5)
50
+ dropout_layer = brainstate.nn.DropoutFixed(in_size=(2, 3), prob=0.5)
51
51
  dropout_layer.init_state(batch_size=2)
52
52
  input_data = np.random.randn(2, 2, 3)
53
- with bst.environ.context(fit=True):
53
+ with brainstate.environ.context(fit=True):
54
54
  output_data = dropout_layer.update(input_data)
55
55
  self.assertEqual(input_data.shape, output_data.shape)
56
56
  self.assertTrue(np.any(output_data == 0))
@@ -72,9 +72,9 @@ class TestDropout(unittest.TestCase):
72
72
  # np.testing.assert_almost_equal(non_zero_elements, expected_non_zero_elements, decimal=4)
73
73
 
74
74
  def test_Dropout2d(self):
75
- dropout_layer = bst.nn.Dropout2d(prob=0.5)
75
+ dropout_layer = brainstate.nn.Dropout2d(prob=0.5)
76
76
  input_data = np.random.randn(2, 3, 4, 5)
77
- with bst.environ.context(fit=True):
77
+ with brainstate.environ.context(fit=True):
78
78
  output_data = dropout_layer(input_data)
79
79
  self.assertEqual(input_data.shape, output_data.shape)
80
80
  self.assertTrue(np.any(output_data == 0))
@@ -84,9 +84,9 @@ class TestDropout(unittest.TestCase):
84
84
  np.testing.assert_almost_equal(non_zero_elements, expected_non_zero_elements, decimal=4)
85
85
 
86
86
  def test_Dropout3d(self):
87
- dropout_layer = bst.nn.Dropout3d(prob=0.5)
87
+ dropout_layer = brainstate.nn.Dropout3d(prob=0.5)
88
88
  input_data = np.random.randn(2, 3, 4, 5, 6)
89
- with bst.environ.context(fit=True):
89
+ with brainstate.environ.context(fit=True):
90
90
  output_data = dropout_layer(input_data)
91
91
  self.assertEqual(input_data.shape, output_data.shape)
92
92
  self.assertTrue(np.any(output_data == 0))
@@ -41,13 +41,14 @@ import numpy as np
41
41
  from brainstate import environ
42
42
  from brainstate._state import State
43
43
  from brainstate.graph import Node
44
- from brainstate.mixin import ParamDescriber
45
- from brainstate.nn._module import Module
44
+ from brainstate.mixin import ParamDescriber, UpdateReturn
46
45
  from brainstate.typing import Size, ArrayLike
47
- from ._state_delay import StateWithDelay, Delay
46
+ from ._delay import StateWithDelay, Delay
47
+ from ._module import Module
48
48
 
49
49
  __all__ = [
50
- 'DynamicsGroup', 'Projection', 'Dynamics', 'Prefetch',
50
+ 'DynamicsGroup', 'Projection', 'Dynamics',
51
+ 'Prefetch', 'PrefetchDelay', 'PrefetchDelayAt', 'OutputDelayAt',
51
52
  ]
52
53
 
53
54
  T = TypeVar('T')
@@ -99,7 +100,7 @@ class Projection(Module):
99
100
  raise ValueError('Do not implement the update() function.')
100
101
 
101
102
 
102
- class Dynamics(Module):
103
+ class Dynamics(Module, UpdateReturn):
103
104
  """
104
105
  Base class for implementing neural dynamics models in BrainState.
105
106
 
@@ -821,6 +822,41 @@ class Dynamics(Module):
821
822
  else:
822
823
  raise TypeError(f'The input {dyn} should be an instance of {Dynamics} or a delayed initializer.')
823
824
 
825
+ def prefetch_delay(self, state: str, delay: Optional[ArrayLike] = None) -> 'PrefetchDelayAt':
826
+ """
827
+ Create a reference to a delayed state or variable in the module.
828
+
829
+ This method simplifies the process of accessing a delayed version of a state or variable
830
+ within the module. It first creates a prefetch reference to the specified state,
831
+ then specifies the delay time for accessing this state.
832
+
833
+ Args:
834
+ state (str): The name of the state or variable to reference.
835
+ delay (Optional[ArrayLike]): The amount of time to delay the variable access,
836
+ typically in time units (e.g., milliseconds). Defaults to None.
837
+
838
+ Returns:
839
+ PrefetchDelayAt: An object that provides access to the variable at the specified delay time.
840
+ """
841
+ return self.prefetch(state).delay.at(delay)
842
+
843
+ def output_delay(self, delay: Optional[ArrayLike] = None) -> 'OutputDelayAt':
844
+ """
845
+ Create a reference to the delayed output of the module.
846
+
847
+ This method simplifies the process of accessing a delayed version of the module's output.
848
+ It instantiates an `OutputDelayAt` object, which can be used to retrieve the output value
849
+ at the specified delay time.
850
+
851
+ Args:
852
+ delay (Optional[ArrayLike]): The amount of time to delay the output access,
853
+ typically in time units (e.g., milliseconds). Defaults to None.
854
+
855
+ Returns:
856
+ OutputDelayAt: An object that provides access to the module's output at the specified delay time.
857
+ """
858
+ return OutputDelayAt(self, delay)
859
+
824
860
 
825
861
  class Prefetch(Node):
826
862
  """
@@ -885,6 +921,7 @@ class Prefetch(Node):
885
921
  An object that provides access to delayed versions of the prefetched item.
886
922
  """
887
923
  return PrefetchDelay(self.module, self.item)
924
+ # return PrefetchDelayAt(self.module, self.item, time)
888
925
 
889
926
  def __call__(self, *args, **kwargs):
890
927
  """
@@ -1007,7 +1044,7 @@ class PrefetchDelayAt(Node):
1007
1044
  self,
1008
1045
  module: Dynamics,
1009
1046
  item: str,
1010
- time: ArrayLike
1047
+ time: ArrayLike = None,
1011
1048
  ):
1012
1049
  """
1013
1050
  Initialize a PrefetchDelayAt object.
@@ -1026,14 +1063,16 @@ class PrefetchDelayAt(Node):
1026
1063
  self.module = module
1027
1064
  self.item = item
1028
1065
  self.time = time
1029
- self.step = u.math.asarray(time / environ.get_dt(), dtype=environ.ditype())
1030
1066
 
1031
- # register the delay
1032
- key = _get_delay_key(item)
1033
- if not module._has_after_update(key):
1034
- module._add_after_update(key, not_receive_update_output(StateWithDelay(module, item)))
1035
- self.state_delay: StateWithDelay = module._get_after_update(key)
1036
- self.state_delay.register_delay(time)
1067
+ if time is not None:
1068
+ self.step = u.math.asarray(time / environ.get_dt(), dtype=environ.ditype())
1069
+
1070
+ # register the delay
1071
+ key = _get_prefetch_delay_key(item)
1072
+ if not module._has_after_update(key):
1073
+ module._add_after_update(key, not_receive_update_output(StateWithDelay(module, item)))
1074
+ self.state_delay: StateWithDelay = module._get_after_update(key)
1075
+ self.state_delay.register_delay(time)
1037
1076
 
1038
1077
  def __call__(self, *args, **kwargs):
1039
1078
  """
@@ -1044,12 +1083,94 @@ class PrefetchDelayAt(Node):
1044
1083
  Any
1045
1084
  The value of the state or variable at the specified delay time.
1046
1085
  """
1047
- # return self.state_delay.retrieve_at_time(self.time)
1048
- return self.state_delay.retrieve_at_step(self.step)
1086
+ if self.time is None:
1087
+ return _get_prefetch_item(self).value
1088
+ else:
1089
+ return self.state_delay.retrieve_at_step(self.step)
1090
+
1091
+
1092
+ class OutputDelayAt(Node):
1093
+ """
1094
+ Provides access to a specific delayed state or variable value at the specific time.
1095
+
1096
+ This class represents the final step in the prefetch delay chain, providing
1097
+ actual access to state values at a specific delay time. It converts the
1098
+ specified time delay into steps and registers the delay with the appropriate
1099
+ StateWithDelay handler.
1100
+
1101
+ Parameters
1102
+ ----------
1103
+ module : Dynamics
1104
+ The dynamics module that contains the referenced state or variable.
1105
+ item : str
1106
+ The name of the state or variable to access with delay.
1107
+ time : ArrayLike
1108
+ The amount of time to delay access by, typically in time units (e.g., milliseconds).
1109
+
1110
+ Examples
1111
+ --------
1112
+ >>> import brainstate
1113
+ >>> import brainunit as u
1114
+ >>> neuron = brainstate.nn.LIF(10)
1115
+ >>> # Create a reference to voltage delayed by 5ms
1116
+ >>> delayed_v = PrefetchDelayAt(neuron, 'V', 5.0 * u.ms)
1117
+ >>> # Get the delayed value
1118
+ >>> v_value = delayed_v()
1119
+ """
1120
+
1121
+ def __init__(
1122
+ self,
1123
+ module: Dynamics,
1124
+ time: Optional[ArrayLike] = None,
1125
+ ):
1126
+ """
1127
+ Initialize a PrefetchDelayAt object.
1128
+
1129
+ Parameters
1130
+ ----------
1131
+ module : AlignPre, Module
1132
+ The dynamics module that contains the referenced state or variable.
1133
+ time : ArrayLike
1134
+ The amount of time to delay access by, typically in time units.
1135
+ """
1136
+ super().__init__()
1137
+ assert isinstance(module, UpdateReturn), 'The module should implement the `update_return` method.'
1138
+ assert isinstance(module, Module), 'The module should be an instance of Module.'
1139
+ self.module = module
1140
+ self.time = time
1141
+ if time is not None:
1142
+ self.step = u.math.asarray(time / environ.get_dt(), dtype=environ.ditype())
1143
+
1144
+ # register the delay
1145
+ key = _get_output_delay_key()
1146
+ if not module._has_after_update(key):
1147
+ # TODO: unit processing
1148
+ delay = Delay(module.update_return(), time)
1149
+ module._add_after_update(key, receive_update_output(delay))
1150
+ self.out_delay: Delay = module._get_after_update(key)
1151
+ self.out_delay.register_delay(time)
1152
+
1153
+ def __call__(self, *args, **kwargs):
1154
+ """
1155
+ Retrieve the value of the state at the specified delay time.
1156
+
1157
+ Returns
1158
+ -------
1159
+ Any
1160
+ The value of the state or variable at the specified delay time.
1161
+ """
1162
+ if self.time is None:
1163
+ return self.module.update_return()
1164
+ else:
1165
+ return self.out_delay.retrieve_at_step(self.step)
1166
+
1167
+
1168
+ def _get_prefetch_delay_key(item) -> str:
1169
+ return f'{item}-prefetch-delay'
1049
1170
 
1050
1171
 
1051
- def _get_delay_key(item) -> str:
1052
- return f'{item}-delay'
1172
+ def _get_output_delay_key() -> str:
1173
+ return f'output-delay'
1053
1174
 
1054
1175
 
1055
1176
  def _get_prefetch_item(target: Union[Prefetch, PrefetchDelayAt]) -> Any:
@@ -1064,7 +1185,7 @@ def _get_prefetch_item_delay(target: Union[Prefetch, PrefetchDelay, PrefetchDela
1064
1185
  f'The target module should be an instance '
1065
1186
  f'of Dynamics. But got {target.module}.'
1066
1187
  )
1067
- delay = target.module._get_after_update(_get_delay_key(target.item))
1188
+ delay = target.module._get_after_update(_get_prefetch_delay_key(target.item))
1068
1189
  if not isinstance(delay, StateWithDelay):
1069
1190
  raise TypeError(f'The prefetch target should be a {StateWithDelay.__name__} when accessing '
1070
1191
  f'its delay. But got {delay}.')
@@ -15,47 +15,46 @@
15
15
 
16
16
  # -*- coding: utf-8 -*-
17
17
 
18
- from __future__ import annotations
19
18
 
20
19
  import unittest
21
20
 
22
21
  import numpy as np
23
22
 
24
- import brainstate as bst
23
+ import brainstate
25
24
 
26
25
 
27
26
  class TestModuleGroup(unittest.TestCase):
28
27
  def test_initialization(self):
29
- group = bst.nn.DynamicsGroup()
30
- self.assertIsInstance(group, bst.nn.DynamicsGroup)
28
+ group = brainstate.nn.DynamicsGroup()
29
+ self.assertIsInstance(group, brainstate.nn.DynamicsGroup)
31
30
 
32
31
 
33
32
  class TestProjection(unittest.TestCase):
34
33
  def test_initialization(self):
35
- proj = bst.nn.Projection()
36
- self.assertIsInstance(proj, bst.nn.Projection)
34
+ proj = brainstate.nn.Projection()
35
+ self.assertIsInstance(proj, brainstate.nn.Projection)
37
36
 
38
37
  def test_update_not_implemented(self):
39
- proj = bst.nn.Projection()
38
+ proj = brainstate.nn.Projection()
40
39
  with self.assertRaises(ValueError):
41
40
  proj.update()
42
41
 
43
42
 
44
43
  class TestDynamics(unittest.TestCase):
45
44
  def test_initialization(self):
46
- dyn = bst.nn.Dynamics(in_size=10)
47
- self.assertIsInstance(dyn, bst.nn.Dynamics)
45
+ dyn = brainstate.nn.Dynamics(in_size=10)
46
+ self.assertIsInstance(dyn, brainstate.nn.Dynamics)
48
47
  self.assertEqual(dyn.in_size, (10,))
49
48
  self.assertEqual(dyn.out_size, (10,))
50
49
 
51
50
  def test_size_validation(self):
52
51
  with self.assertRaises(ValueError):
53
- bst.nn.Dynamics(in_size=[])
52
+ brainstate.nn.Dynamics(in_size=[])
54
53
  with self.assertRaises(ValueError):
55
- bst.nn.Dynamics(in_size="invalid")
54
+ brainstate.nn.Dynamics(in_size="invalid")
56
55
 
57
56
  def test_input_handling(self):
58
- dyn = bst.nn.Dynamics(in_size=10)
57
+ dyn = brainstate.nn.Dynamics(in_size=10)
59
58
  dyn.add_current_input("test_current", lambda: np.random.rand(10))
60
59
  dyn.add_delta_input("test_delta", lambda: np.random.rand(10))
61
60
 
@@ -63,15 +62,15 @@ class TestDynamics(unittest.TestCase):
63
62
  self.assertIn("test_delta", dyn.delta_inputs)
64
63
 
65
64
  def test_duplicate_input_key(self):
66
- dyn = bst.nn.Dynamics(in_size=10)
65
+ dyn = brainstate.nn.Dynamics(in_size=10)
67
66
  dyn.add_current_input("test", lambda: np.random.rand(10))
68
67
  with self.assertRaises(ValueError):
69
68
  dyn.add_current_input("test", lambda: np.random.rand(10))
70
69
 
71
70
  def test_varshape(self):
72
- dyn = bst.nn.Dynamics(in_size=(2, 3))
71
+ dyn = brainstate.nn.Dynamics(in_size=(2, 3))
73
72
  self.assertEqual(dyn.varshape, (2, 3))
74
- dyn = bst.nn.Dynamics(in_size=(2, 3))
73
+ dyn = brainstate.nn.Dynamics(in_size=(2, 3))
75
74
  self.assertEqual(dyn.varshape, (2, 3))
76
75
 
77
76
 
@@ -23,8 +23,8 @@ import jax.typing
23
23
 
24
24
  from brainstate import random, functional as F
25
25
  from brainstate._state import ParamState
26
- from brainstate.nn._module import ElementWiseBlock
27
26
  from brainstate.typing import ArrayLike
27
+ from ._module import ElementWiseBlock
28
28
 
29
29
  __all__ = [
30
30
  # activation functions
@@ -0,0 +1,169 @@
1
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ from absl.testing import absltest
17
+ from absl.testing import parameterized
18
+
19
+ import brainstate
20
+
21
+
22
+ class Test_Activation(parameterized.TestCase):
23
+
24
+ def test_Threshold(self):
25
+ threshold_layer = brainstate.nn.Threshold(5, 20)
26
+ input = brainstate.random.randn(2)
27
+ output = threshold_layer(input)
28
+
29
+ def test_ReLU(self):
30
+ ReLU_layer = brainstate.nn.ReLU()
31
+ input = brainstate.random.randn(2)
32
+ output = ReLU_layer(input)
33
+
34
+ def test_RReLU(self):
35
+ RReLU_layer = brainstate.nn.RReLU(lower=0, upper=1)
36
+ input = brainstate.random.randn(2)
37
+ output = RReLU_layer(input)
38
+
39
+ def test_Hardtanh(self):
40
+ Hardtanh_layer = brainstate.nn.Hardtanh(min_val=0, max_val=1, )
41
+ input = brainstate.random.randn(2)
42
+ output = Hardtanh_layer(input)
43
+
44
+ def test_ReLU6(self):
45
+ ReLU6_layer = brainstate.nn.ReLU6()
46
+ input = brainstate.random.randn(2)
47
+ output = ReLU6_layer(input)
48
+
49
+ def test_Sigmoid(self):
50
+ Sigmoid_layer = brainstate.nn.Sigmoid()
51
+ input = brainstate.random.randn(2)
52
+ output = Sigmoid_layer(input)
53
+
54
+ def test_Hardsigmoid(self):
55
+ Hardsigmoid_layer = brainstate.nn.Hardsigmoid()
56
+ input = brainstate.random.randn(2)
57
+ output = Hardsigmoid_layer(input)
58
+
59
+ def test_Tanh(self):
60
+ Tanh_layer = brainstate.nn.Tanh()
61
+ input = brainstate.random.randn(2)
62
+ output = Tanh_layer(input)
63
+
64
+ def test_SiLU(self):
65
+ SiLU_layer = brainstate.nn.SiLU()
66
+ input = brainstate.random.randn(2)
67
+ output = SiLU_layer(input)
68
+
69
+ def test_Mish(self):
70
+ Mish_layer = brainstate.nn.Mish()
71
+ input = brainstate.random.randn(2)
72
+ output = Mish_layer(input)
73
+
74
+ def test_Hardswish(self):
75
+ Hardswish_layer = brainstate.nn.Hardswish()
76
+ input = brainstate.random.randn(2)
77
+ output = Hardswish_layer(input)
78
+
79
+ def test_ELU(self):
80
+ ELU_layer = brainstate.nn.ELU(alpha=0.5, )
81
+ input = brainstate.random.randn(2)
82
+ output = ELU_layer(input)
83
+
84
+ def test_CELU(self):
85
+ CELU_layer = brainstate.nn.CELU(alpha=0.5, )
86
+ input = brainstate.random.randn(2)
87
+ output = CELU_layer(input)
88
+
89
+ def test_SELU(self):
90
+ SELU_layer = brainstate.nn.SELU()
91
+ input = brainstate.random.randn(2)
92
+ output = SELU_layer(input)
93
+
94
+ def test_GLU(self):
95
+ GLU_layer = brainstate.nn.GLU()
96
+ input = brainstate.random.randn(4, 2)
97
+ output = GLU_layer(input)
98
+
99
+ @parameterized.product(
100
+ approximate=['tanh', 'none']
101
+ )
102
+ def test_GELU(self, approximate):
103
+ GELU_layer = brainstate.nn.GELU()
104
+ input = brainstate.random.randn(2)
105
+ output = GELU_layer(input)
106
+
107
+ def test_Hardshrink(self):
108
+ Hardshrink_layer = brainstate.nn.Hardshrink(lambd=1)
109
+ input = brainstate.random.randn(2)
110
+ output = Hardshrink_layer(input)
111
+
112
+ def test_LeakyReLU(self):
113
+ LeakyReLU_layer = brainstate.nn.LeakyReLU()
114
+ input = brainstate.random.randn(2)
115
+ output = LeakyReLU_layer(input)
116
+
117
+ def test_LogSigmoid(self):
118
+ LogSigmoid_layer = brainstate.nn.LogSigmoid()
119
+ input = brainstate.random.randn(2)
120
+ output = LogSigmoid_layer(input)
121
+
122
+ def test_Softplus(self):
123
+ Softplus_layer = brainstate.nn.Softplus()
124
+ input = brainstate.random.randn(2)
125
+ output = Softplus_layer(input)
126
+
127
+ def test_Softshrink(self):
128
+ Softshrink_layer = brainstate.nn.Softshrink(lambd=1)
129
+ input = brainstate.random.randn(2)
130
+ output = Softshrink_layer(input)
131
+
132
+ def test_PReLU(self):
133
+ PReLU_layer = brainstate.nn.PReLU(num_parameters=2, init=0.5)
134
+ input = brainstate.random.randn(2)
135
+ output = PReLU_layer(input)
136
+
137
+ def test_Softsign(self):
138
+ Softsign_layer = brainstate.nn.Softsign()
139
+ input = brainstate.random.randn(2)
140
+ output = Softsign_layer(input)
141
+
142
+ def test_Tanhshrink(self):
143
+ Tanhshrink_layer = brainstate.nn.Tanhshrink()
144
+ input = brainstate.random.randn(2)
145
+ output = Tanhshrink_layer(input)
146
+
147
+ def test_Softmin(self):
148
+ Softmin_layer = brainstate.nn.Softmin(dim=2)
149
+ input = brainstate.random.randn(2, 3, 4)
150
+ output = Softmin_layer(input)
151
+
152
+ def test_Softmax(self):
153
+ Softmax_layer = brainstate.nn.Softmax(dim=2)
154
+ input = brainstate.random.randn(2, 3, 4)
155
+ output = Softmax_layer(input)
156
+
157
+ def test_Softmax2d(self):
158
+ Softmax2d_layer = brainstate.nn.Softmax2d()
159
+ input = brainstate.random.randn(2, 3, 12, 13)
160
+ output = Softmax2d_layer(input)
161
+
162
+ def test_LogSoftmax(self):
163
+ LogSoftmax_layer = brainstate.nn.LogSoftmax(dim=2)
164
+ input = brainstate.random.randn(2, 3, 4)
165
+ output = LogSoftmax_layer(input)
166
+
167
+
168
+ if __name__ == '__main__':
169
+ absltest.main()
@@ -17,8 +17,8 @@ from typing import Optional, Callable, Union
17
17
 
18
18
  from brainstate import init
19
19
  from brainstate._state import ParamState
20
- from brainstate.nn._module import Module
21
20
  from brainstate.typing import ArrayLike
21
+ from ._module import Module
22
22
 
23
23
  __all__ = [
24
24
  'Embedding',
@@ -13,13 +13,12 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- from __future__ import annotations
17
16
 
18
17
  import unittest
19
18
 
20
19
  import brainunit as u
21
20
 
22
- import brainstate as bst
21
+ import brainstate
23
22
 
24
23
 
25
24
  class TestExpEuler(unittest.TestCase):
@@ -27,10 +26,10 @@ class TestExpEuler(unittest.TestCase):
27
26
  def fun(x, tau):
28
27
  return -x / tau
29
28
 
30
- with bst.environ.context(dt=0.1):
29
+ with brainstate.environ.context(dt=0.1):
31
30
  with self.assertRaises(AssertionError):
32
- r = bst.nn.exp_euler_step(fun, 1.0 * u.mV, 1. * u.ms)
31
+ r = brainstate.nn.exp_euler_step(fun, 1.0 * u.mV, 1. * u.ms)
33
32
 
34
- with bst.environ.context(dt=1. * u.ms):
35
- r = bst.nn.exp_euler_step(fun, 1.0 * u.mV, 1. * u.ms)
33
+ with brainstate.environ.context(dt=1. * u.ms):
34
+ r = brainstate.nn.exp_euler_step(fun, 1.0 * u.mV, 1. * u.ms)
36
35
  print(r)
@@ -25,8 +25,8 @@ from brainstate import random, augment, environ, init
25
25
  from brainstate._compatible_import import brainevent
26
26
  from brainstate._state import ParamState
27
27
  from brainstate.compile import for_loop
28
- from brainstate.nn._module import Module
29
28
  from brainstate.typing import Size, ArrayLike
29
+ from ._module import Module
30
30
 
31
31
  __all__ = [
32
32
  'EventFixedNumConn',
@@ -13,7 +13,6 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- from __future__ import annotations
17
16
 
18
17
  import jax.numpy
19
18
  import jax.numpy as jnp
@@ -20,12 +20,11 @@ import jax
20
20
  import numpy as np
21
21
 
22
22
  from brainstate import environ, init, random
23
- from brainstate._state import ShortTermState
24
- from brainstate._state import State, maybe_state
23
+ from brainstate._state import ShortTermState, State, maybe_state
25
24
  from brainstate.compile import while_loop
26
- from brainstate.nn._dynamics._dynamics_base import Dynamics, Prefetch
27
- from brainstate.nn._module import Module
28
25
  from brainstate.typing import ArrayLike, Size, DTypeLike
26
+ from ._dynamics import Dynamics, Prefetch
27
+ from ._module import Module
29
28
 
30
29
  __all__ = [
31
30
  'SpikeTime',
@@ -134,7 +133,7 @@ class PoissonSpike(Dynamics):
134
133
  self.freqs = init.param(freqs, self.varshape, allow_none=False)
135
134
 
136
135
  def update(self):
137
- spikes = random.rand(self.varshape) <= (self.freqs * environ.get_dt())
136
+ spikes = random.rand(*self.varshape) <= (self.freqs * environ.get_dt())
138
137
  spikes = u.math.asarray(spikes, dtype=self.spk_type)
139
138
  return spikes
140
139
 
@@ -22,8 +22,8 @@ import jax.numpy as jnp
22
22
 
23
23
  from brainstate import init, functional
24
24
  from brainstate._state import ParamState
25
- from brainstate.nn._module import Module
26
25
  from brainstate.typing import ArrayLike, Size
26
+ from ._module import Module
27
27
 
28
28
  __all__ = [
29
29
  'Linear',
@@ -350,10 +350,7 @@ class OneToOne(Module):
350
350
  self.weight = param_type(param)
351
351
 
352
352
  def update(self, pre_val):
353
- pre_val, pre_unit = u.get_mantissa(pre_val), u.get_unit(pre_val)
354
- w_val, w_unit = u.get_mantissa(self.weight.value['weight']), u.get_unit(self.weight.value['weight'])
355
- post_val = pre_val * w_val
356
- post_val = u.maybe_decimal(u.Quantity(post_val, unit=w_unit * pre_unit))
353
+ post_val = pre_val * self.weight.value['weight']
357
354
  if 'bias' in self.weight.value:
358
355
  post_val = post_val + self.weight.value['bias']
359
356
  return post_val
@@ -21,8 +21,8 @@ import jax
21
21
  from brainstate import init
22
22
  from brainstate._compatible_import import brainevent
23
23
  from brainstate._state import ParamState
24
- from brainstate.nn._module import Module
25
24
  from brainstate.typing import Size, ArrayLike
25
+ from ._module import Module
26
26
 
27
27
  __all__ = [
28
28
  'EventLinear',
@@ -13,7 +13,6 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- from __future__ import annotations
17
16
 
18
17
  import jax
19
18
  import jax.numpy as jnp