brainstate 0.1.1__py2.py3-none-any.whl → 0.1.3__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +1 -1
- brainstate/_compatible_import.py +12 -9
- brainstate/_state.py +1 -1
- brainstate/augment/_autograd_test.py +132 -133
- brainstate/augment/_eval_shape_test.py +7 -9
- brainstate/augment/_mapping_test.py +75 -76
- brainstate/compile/_ad_checkpoint_test.py +6 -8
- brainstate/compile/_conditions_test.py +35 -36
- brainstate/compile/_error_if_test.py +10 -13
- brainstate/compile/_loop_collect_return_test.py +7 -9
- brainstate/compile/_loop_no_collection_test.py +7 -8
- brainstate/compile/_make_jaxpr.py +29 -14
- brainstate/compile/_make_jaxpr_test.py +20 -20
- brainstate/functional/_activations_test.py +61 -61
- brainstate/graph/_graph_node_test.py +16 -18
- brainstate/graph/_graph_operation_test.py +154 -156
- brainstate/init/_random_inits_test.py +20 -21
- brainstate/init/_regular_inits_test.py +4 -5
- brainstate/mixin.py +1 -14
- brainstate/nn/__init__.py +81 -17
- brainstate/nn/_collective_ops_test.py +8 -8
- brainstate/nn/{_interaction/_conv.py → _conv.py} +1 -1
- brainstate/nn/{_interaction/_conv_test.py → _conv_test.py} +31 -33
- brainstate/nn/{_dynamics/_state_delay.py → _delay.py} +6 -2
- brainstate/nn/{_elementwise/_dropout.py → _dropout.py} +1 -1
- brainstate/nn/{_elementwise/_dropout_test.py → _dropout_test.py} +9 -9
- brainstate/nn/{_dynamics/_dynamics_base.py → _dynamics.py} +139 -18
- brainstate/nn/{_dynamics/_dynamics_base_test.py → _dynamics_test.py} +14 -15
- brainstate/nn/{_elementwise/_elementwise.py → _elementwise.py} +1 -1
- brainstate/nn/_elementwise_test.py +169 -0
- brainstate/nn/{_interaction/_embedding.py → _embedding.py} +1 -1
- brainstate/nn/_exp_euler_test.py +5 -6
- brainstate/nn/{_event/_fixedprob_mv.py → _fixedprob_mv.py} +1 -1
- brainstate/nn/{_event/_fixedprob_mv_test.py → _fixedprob_mv_test.py} +0 -1
- brainstate/nn/{_dyn_impl/_inputs.py → _inputs.py} +4 -5
- brainstate/nn/{_interaction/_linear.py → _linear.py} +2 -5
- brainstate/nn/{_event/_linear_mv.py → _linear_mv.py} +1 -1
- brainstate/nn/{_event/_linear_mv_test.py → _linear_mv_test.py} +0 -1
- brainstate/nn/{_interaction/_linear_test.py → _linear_test.py} +15 -17
- brainstate/nn/{_event/__init__.py → _ltp.py} +7 -5
- brainstate/nn/_module_test.py +34 -37
- brainstate/nn/{_dyn_impl/_dynamics_neuron.py → _neuron.py} +2 -2
- brainstate/nn/{_dyn_impl/_dynamics_neuron_test.py → _neuron_test.py} +18 -19
- brainstate/nn/{_interaction/_normalizations.py → _normalizations.py} +1 -1
- brainstate/nn/{_interaction/_normalizations_test.py → _normalizations_test.py} +10 -12
- brainstate/nn/{_interaction/_poolings.py → _poolings.py} +1 -1
- brainstate/nn/{_interaction/_poolings_test.py → _poolings_test.py} +20 -22
- brainstate/nn/{_dynamics/_projection_base.py → _projection.py} +35 -3
- brainstate/nn/{_dyn_impl/_rate_rnns.py → _rate_rnns.py} +2 -2
- brainstate/nn/{_dyn_impl/_rate_rnns_test.py → _rate_rnns_test.py} +6 -7
- brainstate/nn/{_dyn_impl/_readout.py → _readout.py} +3 -3
- brainstate/nn/{_dyn_impl/_readout_test.py → _readout_test.py} +9 -10
- brainstate/nn/_stp.py +236 -0
- brainstate/nn/{_dyn_impl/_dynamics_synapse.py → _synapse.py} +17 -206
- brainstate/nn/{_dyn_impl/_dynamics_synapse_test.py → _synapse_test.py} +9 -10
- brainstate/nn/_synaptic_projection.py +133 -0
- brainstate/nn/{_dynamics/_synouts.py → _synouts.py} +4 -1
- brainstate/nn/{_dynamics/_synouts_test.py → _synouts_test.py} +4 -5
- brainstate/optim/_lr_scheduler_test.py +3 -3
- brainstate/optim/_optax_optimizer_test.py +8 -9
- brainstate/random/_rand_funs_test.py +183 -184
- brainstate/random/_rand_seed_test.py +10 -12
- {brainstate-0.1.1.dist-info → brainstate-0.1.3.dist-info}/METADATA +1 -1
- brainstate-0.1.3.dist-info/RECORD +131 -0
- brainstate/nn/_dyn_impl/__init__.py +0 -42
- brainstate/nn/_dynamics/__init__.py +0 -37
- brainstate/nn/_elementwise/__init__.py +0 -22
- brainstate/nn/_elementwise/_elementwise_test.py +0 -171
- brainstate/nn/_interaction/__init__.py +0 -41
- brainstate-0.1.1.dist-info/RECORD +0 -133
- {brainstate-0.1.1.dist-info → brainstate-0.1.3.dist-info}/LICENSE +0 -0
- {brainstate-0.1.1.dist-info → brainstate-0.1.3.dist-info}/WHEEL +0 -0
- {brainstate-0.1.1.dist-info → brainstate-0.1.3.dist-info}/top_level.txt +0 -0
@@ -18,19 +18,19 @@ import unittest
|
|
18
18
|
|
19
19
|
import numpy as np
|
20
20
|
|
21
|
-
import brainstate
|
21
|
+
import brainstate
|
22
22
|
|
23
23
|
|
24
24
|
class TestDropout(unittest.TestCase):
|
25
25
|
|
26
26
|
def test_dropout(self):
|
27
27
|
# Create a Dropout layer with a dropout rate of 0.5
|
28
|
-
dropout_layer =
|
28
|
+
dropout_layer = brainstate.nn.Dropout(0.5)
|
29
29
|
|
30
30
|
# Input data
|
31
31
|
input_data = np.arange(20)
|
32
32
|
|
33
|
-
with
|
33
|
+
with brainstate.environ.context(fit=True):
|
34
34
|
# Apply dropout
|
35
35
|
output_data = dropout_layer(input_data)
|
36
36
|
|
@@ -47,10 +47,10 @@ class TestDropout(unittest.TestCase):
|
|
47
47
|
np.testing.assert_almost_equal(non_zero_elements, expected_non_zero_elements)
|
48
48
|
|
49
49
|
def test_DropoutFixed(self):
|
50
|
-
dropout_layer =
|
50
|
+
dropout_layer = brainstate.nn.DropoutFixed(in_size=(2, 3), prob=0.5)
|
51
51
|
dropout_layer.init_state(batch_size=2)
|
52
52
|
input_data = np.random.randn(2, 2, 3)
|
53
|
-
with
|
53
|
+
with brainstate.environ.context(fit=True):
|
54
54
|
output_data = dropout_layer.update(input_data)
|
55
55
|
self.assertEqual(input_data.shape, output_data.shape)
|
56
56
|
self.assertTrue(np.any(output_data == 0))
|
@@ -72,9 +72,9 @@ class TestDropout(unittest.TestCase):
|
|
72
72
|
# np.testing.assert_almost_equal(non_zero_elements, expected_non_zero_elements, decimal=4)
|
73
73
|
|
74
74
|
def test_Dropout2d(self):
|
75
|
-
dropout_layer =
|
75
|
+
dropout_layer = brainstate.nn.Dropout2d(prob=0.5)
|
76
76
|
input_data = np.random.randn(2, 3, 4, 5)
|
77
|
-
with
|
77
|
+
with brainstate.environ.context(fit=True):
|
78
78
|
output_data = dropout_layer(input_data)
|
79
79
|
self.assertEqual(input_data.shape, output_data.shape)
|
80
80
|
self.assertTrue(np.any(output_data == 0))
|
@@ -84,9 +84,9 @@ class TestDropout(unittest.TestCase):
|
|
84
84
|
np.testing.assert_almost_equal(non_zero_elements, expected_non_zero_elements, decimal=4)
|
85
85
|
|
86
86
|
def test_Dropout3d(self):
|
87
|
-
dropout_layer =
|
87
|
+
dropout_layer = brainstate.nn.Dropout3d(prob=0.5)
|
88
88
|
input_data = np.random.randn(2, 3, 4, 5, 6)
|
89
|
-
with
|
89
|
+
with brainstate.environ.context(fit=True):
|
90
90
|
output_data = dropout_layer(input_data)
|
91
91
|
self.assertEqual(input_data.shape, output_data.shape)
|
92
92
|
self.assertTrue(np.any(output_data == 0))
|
@@ -41,13 +41,14 @@ import numpy as np
|
|
41
41
|
from brainstate import environ
|
42
42
|
from brainstate._state import State
|
43
43
|
from brainstate.graph import Node
|
44
|
-
from brainstate.mixin import ParamDescriber
|
45
|
-
from brainstate.nn._module import Module
|
44
|
+
from brainstate.mixin import ParamDescriber, UpdateReturn
|
46
45
|
from brainstate.typing import Size, ArrayLike
|
47
|
-
from .
|
46
|
+
from ._delay import StateWithDelay, Delay
|
47
|
+
from ._module import Module
|
48
48
|
|
49
49
|
__all__ = [
|
50
|
-
'DynamicsGroup', 'Projection', 'Dynamics',
|
50
|
+
'DynamicsGroup', 'Projection', 'Dynamics',
|
51
|
+
'Prefetch', 'PrefetchDelay', 'PrefetchDelayAt', 'OutputDelayAt',
|
51
52
|
]
|
52
53
|
|
53
54
|
T = TypeVar('T')
|
@@ -99,7 +100,7 @@ class Projection(Module):
|
|
99
100
|
raise ValueError('Do not implement the update() function.')
|
100
101
|
|
101
102
|
|
102
|
-
class Dynamics(Module):
|
103
|
+
class Dynamics(Module, UpdateReturn):
|
103
104
|
"""
|
104
105
|
Base class for implementing neural dynamics models in BrainState.
|
105
106
|
|
@@ -821,6 +822,41 @@ class Dynamics(Module):
|
|
821
822
|
else:
|
822
823
|
raise TypeError(f'The input {dyn} should be an instance of {Dynamics} or a delayed initializer.')
|
823
824
|
|
825
|
+
def prefetch_delay(self, state: str, delay: Optional[ArrayLike] = None) -> 'PrefetchDelayAt':
|
826
|
+
"""
|
827
|
+
Create a reference to a delayed state or variable in the module.
|
828
|
+
|
829
|
+
This method simplifies the process of accessing a delayed version of a state or variable
|
830
|
+
within the module. It first creates a prefetch reference to the specified state,
|
831
|
+
then specifies the delay time for accessing this state.
|
832
|
+
|
833
|
+
Args:
|
834
|
+
state (str): The name of the state or variable to reference.
|
835
|
+
delay (Optional[ArrayLike]): The amount of time to delay the variable access,
|
836
|
+
typically in time units (e.g., milliseconds). Defaults to None.
|
837
|
+
|
838
|
+
Returns:
|
839
|
+
PrefetchDelayAt: An object that provides access to the variable at the specified delay time.
|
840
|
+
"""
|
841
|
+
return self.prefetch(state).delay.at(delay)
|
842
|
+
|
843
|
+
def output_delay(self, delay: Optional[ArrayLike] = None) -> 'OutputDelayAt':
|
844
|
+
"""
|
845
|
+
Create a reference to the delayed output of the module.
|
846
|
+
|
847
|
+
This method simplifies the process of accessing a delayed version of the module's output.
|
848
|
+
It instantiates an `OutputDelayAt` object, which can be used to retrieve the output value
|
849
|
+
at the specified delay time.
|
850
|
+
|
851
|
+
Args:
|
852
|
+
delay (Optional[ArrayLike]): The amount of time to delay the output access,
|
853
|
+
typically in time units (e.g., milliseconds). Defaults to None.
|
854
|
+
|
855
|
+
Returns:
|
856
|
+
OutputDelayAt: An object that provides access to the module's output at the specified delay time.
|
857
|
+
"""
|
858
|
+
return OutputDelayAt(self, delay)
|
859
|
+
|
824
860
|
|
825
861
|
class Prefetch(Node):
|
826
862
|
"""
|
@@ -885,6 +921,7 @@ class Prefetch(Node):
|
|
885
921
|
An object that provides access to delayed versions of the prefetched item.
|
886
922
|
"""
|
887
923
|
return PrefetchDelay(self.module, self.item)
|
924
|
+
# return PrefetchDelayAt(self.module, self.item, time)
|
888
925
|
|
889
926
|
def __call__(self, *args, **kwargs):
|
890
927
|
"""
|
@@ -1007,7 +1044,7 @@ class PrefetchDelayAt(Node):
|
|
1007
1044
|
self,
|
1008
1045
|
module: Dynamics,
|
1009
1046
|
item: str,
|
1010
|
-
time: ArrayLike
|
1047
|
+
time: ArrayLike = None,
|
1011
1048
|
):
|
1012
1049
|
"""
|
1013
1050
|
Initialize a PrefetchDelayAt object.
|
@@ -1026,14 +1063,16 @@ class PrefetchDelayAt(Node):
|
|
1026
1063
|
self.module = module
|
1027
1064
|
self.item = item
|
1028
1065
|
self.time = time
|
1029
|
-
self.step = u.math.asarray(time / environ.get_dt(), dtype=environ.ditype())
|
1030
1066
|
|
1031
|
-
|
1032
|
-
|
1033
|
-
|
1034
|
-
|
1035
|
-
|
1036
|
-
|
1067
|
+
if time is not None:
|
1068
|
+
self.step = u.math.asarray(time / environ.get_dt(), dtype=environ.ditype())
|
1069
|
+
|
1070
|
+
# register the delay
|
1071
|
+
key = _get_prefetch_delay_key(item)
|
1072
|
+
if not module._has_after_update(key):
|
1073
|
+
module._add_after_update(key, not_receive_update_output(StateWithDelay(module, item)))
|
1074
|
+
self.state_delay: StateWithDelay = module._get_after_update(key)
|
1075
|
+
self.state_delay.register_delay(time)
|
1037
1076
|
|
1038
1077
|
def __call__(self, *args, **kwargs):
|
1039
1078
|
"""
|
@@ -1044,12 +1083,94 @@ class PrefetchDelayAt(Node):
|
|
1044
1083
|
Any
|
1045
1084
|
The value of the state or variable at the specified delay time.
|
1046
1085
|
"""
|
1047
|
-
|
1048
|
-
|
1086
|
+
if self.time is None:
|
1087
|
+
return _get_prefetch_item(self).value
|
1088
|
+
else:
|
1089
|
+
return self.state_delay.retrieve_at_step(self.step)
|
1090
|
+
|
1091
|
+
|
1092
|
+
class OutputDelayAt(Node):
|
1093
|
+
"""
|
1094
|
+
Provides access to the delayed output of a module at a specific time.
|
1095
|
+
|
1096
|
+
This class represents the final step in the prefetch delay chain, providing
|
1097
|
+
actual access to state values at a specific delay time. It converts the
|
1098
|
+
specified time delay into steps and registers the delay with the appropriate
|
1099
|
+
StateWithDelay handler.
|
1100
|
+
|
1101
|
+
Parameters
|
1102
|
+
----------
|
1103
|
+
module : Dynamics
|
1104
|
+
The dynamics module that contains the referenced state or variable.
|
1105
|
+
item : str
|
1106
|
+
The name of the state or variable to access with delay.
|
1107
|
+
time : ArrayLike
|
1108
|
+
The amount of time to delay access by, typically in time units (e.g., milliseconds).
|
1109
|
+
|
1110
|
+
Examples
|
1111
|
+
--------
|
1112
|
+
>>> import brainstate
|
1113
|
+
>>> import brainunit as u
|
1114
|
+
>>> neuron = brainstate.nn.LIF(10)
|
1115
|
+
>>> # Create a reference to voltage delayed by 5ms
|
1116
|
+
>>> delayed_v = PrefetchDelayAt(neuron, 'V', 5.0 * u.ms)
|
1117
|
+
>>> # Get the delayed value
|
1118
|
+
>>> v_value = delayed_v()
|
1119
|
+
"""
|
1120
|
+
|
1121
|
+
def __init__(
|
1122
|
+
self,
|
1123
|
+
module: Dynamics,
|
1124
|
+
time: Optional[ArrayLike] = None,
|
1125
|
+
):
|
1126
|
+
"""
|
1127
|
+
Initialize an OutputDelayAt object.
|
1128
|
+
|
1129
|
+
Parameters
|
1130
|
+
----------
|
1131
|
+
module : AlignPre, Module
|
1132
|
+
The dynamics module that contains the referenced state or variable.
|
1133
|
+
time : ArrayLike
|
1134
|
+
The amount of time to delay access by, typically in time units.
|
1135
|
+
"""
|
1136
|
+
super().__init__()
|
1137
|
+
assert isinstance(module, UpdateReturn), 'The module should implement the `update_return` method.'
|
1138
|
+
assert isinstance(module, Module), 'The module should be an instance of Module.'
|
1139
|
+
self.module = module
|
1140
|
+
self.time = time
|
1141
|
+
if time is not None:
|
1142
|
+
self.step = u.math.asarray(time / environ.get_dt(), dtype=environ.ditype())
|
1143
|
+
|
1144
|
+
# register the delay
|
1145
|
+
key = _get_output_delay_key()
|
1146
|
+
if not module._has_after_update(key):
|
1147
|
+
# TODO: unit processing
|
1148
|
+
delay = Delay(module.update_return(), time)
|
1149
|
+
module._add_after_update(key, receive_update_output(delay))
|
1150
|
+
self.out_delay: Delay = module._get_after_update(key)
|
1151
|
+
self.out_delay.register_delay(time)
|
1152
|
+
|
1153
|
+
def __call__(self, *args, **kwargs):
|
1154
|
+
"""
|
1155
|
+
Retrieve the value of the module output at the specified delay time.
|
1156
|
+
|
1157
|
+
Returns
|
1158
|
+
-------
|
1159
|
+
Any
|
1160
|
+
The value of the state or variable at the specified delay time.
|
1161
|
+
"""
|
1162
|
+
if self.time is None:
|
1163
|
+
return self.module.update_return()
|
1164
|
+
else:
|
1165
|
+
return self.out_delay.retrieve_at_step(self.step)
|
1166
|
+
|
1167
|
+
|
1168
|
+
def _get_prefetch_delay_key(item) -> str:
|
1169
|
+
return f'{item}-prefetch-delay'
|
1049
1170
|
|
1050
1171
|
|
1051
|
-
def
|
1052
|
-
return f'
|
1172
|
+
def _get_output_delay_key() -> str:
|
1173
|
+
return f'output-delay'
|
1053
1174
|
|
1054
1175
|
|
1055
1176
|
def _get_prefetch_item(target: Union[Prefetch, PrefetchDelayAt]) -> Any:
|
@@ -1064,7 +1185,7 @@ def _get_prefetch_item_delay(target: Union[Prefetch, PrefetchDelay, PrefetchDela
|
|
1064
1185
|
f'The target module should be an instance '
|
1065
1186
|
f'of Dynamics. But got {target.module}.'
|
1066
1187
|
)
|
1067
|
-
delay = target.module._get_after_update(
|
1188
|
+
delay = target.module._get_after_update(_get_prefetch_delay_key(target.item))
|
1068
1189
|
if not isinstance(delay, StateWithDelay):
|
1069
1190
|
raise TypeError(f'The prefetch target should be a {StateWithDelay.__name__} when accessing '
|
1070
1191
|
f'its delay. But got {delay}.')
|
@@ -15,47 +15,46 @@
|
|
15
15
|
|
16
16
|
# -*- coding: utf-8 -*-
|
17
17
|
|
18
|
-
from __future__ import annotations
|
19
18
|
|
20
19
|
import unittest
|
21
20
|
|
22
21
|
import numpy as np
|
23
22
|
|
24
|
-
import brainstate
|
23
|
+
import brainstate
|
25
24
|
|
26
25
|
|
27
26
|
class TestModuleGroup(unittest.TestCase):
|
28
27
|
def test_initialization(self):
|
29
|
-
group =
|
30
|
-
self.assertIsInstance(group,
|
28
|
+
group = brainstate.nn.DynamicsGroup()
|
29
|
+
self.assertIsInstance(group, brainstate.nn.DynamicsGroup)
|
31
30
|
|
32
31
|
|
33
32
|
class TestProjection(unittest.TestCase):
|
34
33
|
def test_initialization(self):
|
35
|
-
proj =
|
36
|
-
self.assertIsInstance(proj,
|
34
|
+
proj = brainstate.nn.Projection()
|
35
|
+
self.assertIsInstance(proj, brainstate.nn.Projection)
|
37
36
|
|
38
37
|
def test_update_not_implemented(self):
|
39
|
-
proj =
|
38
|
+
proj = brainstate.nn.Projection()
|
40
39
|
with self.assertRaises(ValueError):
|
41
40
|
proj.update()
|
42
41
|
|
43
42
|
|
44
43
|
class TestDynamics(unittest.TestCase):
|
45
44
|
def test_initialization(self):
|
46
|
-
dyn =
|
47
|
-
self.assertIsInstance(dyn,
|
45
|
+
dyn = brainstate.nn.Dynamics(in_size=10)
|
46
|
+
self.assertIsInstance(dyn, brainstate.nn.Dynamics)
|
48
47
|
self.assertEqual(dyn.in_size, (10,))
|
49
48
|
self.assertEqual(dyn.out_size, (10,))
|
50
49
|
|
51
50
|
def test_size_validation(self):
|
52
51
|
with self.assertRaises(ValueError):
|
53
|
-
|
52
|
+
brainstate.nn.Dynamics(in_size=[])
|
54
53
|
with self.assertRaises(ValueError):
|
55
|
-
|
54
|
+
brainstate.nn.Dynamics(in_size="invalid")
|
56
55
|
|
57
56
|
def test_input_handling(self):
|
58
|
-
dyn =
|
57
|
+
dyn = brainstate.nn.Dynamics(in_size=10)
|
59
58
|
dyn.add_current_input("test_current", lambda: np.random.rand(10))
|
60
59
|
dyn.add_delta_input("test_delta", lambda: np.random.rand(10))
|
61
60
|
|
@@ -63,15 +62,15 @@ class TestDynamics(unittest.TestCase):
|
|
63
62
|
self.assertIn("test_delta", dyn.delta_inputs)
|
64
63
|
|
65
64
|
def test_duplicate_input_key(self):
|
66
|
-
dyn =
|
65
|
+
dyn = brainstate.nn.Dynamics(in_size=10)
|
67
66
|
dyn.add_current_input("test", lambda: np.random.rand(10))
|
68
67
|
with self.assertRaises(ValueError):
|
69
68
|
dyn.add_current_input("test", lambda: np.random.rand(10))
|
70
69
|
|
71
70
|
def test_varshape(self):
|
72
|
-
dyn =
|
71
|
+
dyn = brainstate.nn.Dynamics(in_size=(2, 3))
|
73
72
|
self.assertEqual(dyn.varshape, (2, 3))
|
74
|
-
dyn =
|
73
|
+
dyn = brainstate.nn.Dynamics(in_size=(2, 3))
|
75
74
|
self.assertEqual(dyn.varshape, (2, 3))
|
76
75
|
|
77
76
|
|
@@ -23,8 +23,8 @@ import jax.typing
|
|
23
23
|
|
24
24
|
from brainstate import random, functional as F
|
25
25
|
from brainstate._state import ParamState
|
26
|
-
from brainstate.nn._module import ElementWiseBlock
|
27
26
|
from brainstate.typing import ArrayLike
|
27
|
+
from ._module import ElementWiseBlock
|
28
28
|
|
29
29
|
__all__ = [
|
30
30
|
# activation functions
|
@@ -0,0 +1,169 @@
|
|
1
|
+
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
from absl.testing import absltest
|
17
|
+
from absl.testing import parameterized
|
18
|
+
|
19
|
+
import brainstate
|
20
|
+
|
21
|
+
|
22
|
+
class Test_Activation(parameterized.TestCase):
|
23
|
+
|
24
|
+
def test_Threshold(self):
|
25
|
+
threshold_layer = brainstate.nn.Threshold(5, 20)
|
26
|
+
input = brainstate.random.randn(2)
|
27
|
+
output = threshold_layer(input)
|
28
|
+
|
29
|
+
def test_ReLU(self):
|
30
|
+
ReLU_layer = brainstate.nn.ReLU()
|
31
|
+
input = brainstate.random.randn(2)
|
32
|
+
output = ReLU_layer(input)
|
33
|
+
|
34
|
+
def test_RReLU(self):
|
35
|
+
RReLU_layer = brainstate.nn.RReLU(lower=0, upper=1)
|
36
|
+
input = brainstate.random.randn(2)
|
37
|
+
output = RReLU_layer(input)
|
38
|
+
|
39
|
+
def test_Hardtanh(self):
|
40
|
+
Hardtanh_layer = brainstate.nn.Hardtanh(min_val=0, max_val=1, )
|
41
|
+
input = brainstate.random.randn(2)
|
42
|
+
output = Hardtanh_layer(input)
|
43
|
+
|
44
|
+
def test_ReLU6(self):
|
45
|
+
ReLU6_layer = brainstate.nn.ReLU6()
|
46
|
+
input = brainstate.random.randn(2)
|
47
|
+
output = ReLU6_layer(input)
|
48
|
+
|
49
|
+
def test_Sigmoid(self):
|
50
|
+
Sigmoid_layer = brainstate.nn.Sigmoid()
|
51
|
+
input = brainstate.random.randn(2)
|
52
|
+
output = Sigmoid_layer(input)
|
53
|
+
|
54
|
+
def test_Hardsigmoid(self):
|
55
|
+
Hardsigmoid_layer = brainstate.nn.Hardsigmoid()
|
56
|
+
input = brainstate.random.randn(2)
|
57
|
+
output = Hardsigmoid_layer(input)
|
58
|
+
|
59
|
+
def test_Tanh(self):
|
60
|
+
Tanh_layer = brainstate.nn.Tanh()
|
61
|
+
input = brainstate.random.randn(2)
|
62
|
+
output = Tanh_layer(input)
|
63
|
+
|
64
|
+
def test_SiLU(self):
|
65
|
+
SiLU_layer = brainstate.nn.SiLU()
|
66
|
+
input = brainstate.random.randn(2)
|
67
|
+
output = SiLU_layer(input)
|
68
|
+
|
69
|
+
def test_Mish(self):
|
70
|
+
Mish_layer = brainstate.nn.Mish()
|
71
|
+
input = brainstate.random.randn(2)
|
72
|
+
output = Mish_layer(input)
|
73
|
+
|
74
|
+
def test_Hardswish(self):
|
75
|
+
Hardswish_layer = brainstate.nn.Hardswish()
|
76
|
+
input = brainstate.random.randn(2)
|
77
|
+
output = Hardswish_layer(input)
|
78
|
+
|
79
|
+
def test_ELU(self):
|
80
|
+
ELU_layer = brainstate.nn.ELU(alpha=0.5, )
|
81
|
+
input = brainstate.random.randn(2)
|
82
|
+
output = ELU_layer(input)
|
83
|
+
|
84
|
+
def test_CELU(self):
|
85
|
+
CELU_layer = brainstate.nn.CELU(alpha=0.5, )
|
86
|
+
input = brainstate.random.randn(2)
|
87
|
+
output = CELU_layer(input)
|
88
|
+
|
89
|
+
def test_SELU(self):
|
90
|
+
SELU_layer = brainstate.nn.SELU()
|
91
|
+
input = brainstate.random.randn(2)
|
92
|
+
output = SELU_layer(input)
|
93
|
+
|
94
|
+
def test_GLU(self):
|
95
|
+
GLU_layer = brainstate.nn.GLU()
|
96
|
+
input = brainstate.random.randn(4, 2)
|
97
|
+
output = GLU_layer(input)
|
98
|
+
|
99
|
+
@parameterized.product(
|
100
|
+
approximate=['tanh', 'none']
|
101
|
+
)
|
102
|
+
def test_GELU(self, approximate):
|
103
|
+
GELU_layer = brainstate.nn.GELU()
|
104
|
+
input = brainstate.random.randn(2)
|
105
|
+
output = GELU_layer(input)
|
106
|
+
|
107
|
+
def test_Hardshrink(self):
|
108
|
+
Hardshrink_layer = brainstate.nn.Hardshrink(lambd=1)
|
109
|
+
input = brainstate.random.randn(2)
|
110
|
+
output = Hardshrink_layer(input)
|
111
|
+
|
112
|
+
def test_LeakyReLU(self):
|
113
|
+
LeakyReLU_layer = brainstate.nn.LeakyReLU()
|
114
|
+
input = brainstate.random.randn(2)
|
115
|
+
output = LeakyReLU_layer(input)
|
116
|
+
|
117
|
+
def test_LogSigmoid(self):
|
118
|
+
LogSigmoid_layer = brainstate.nn.LogSigmoid()
|
119
|
+
input = brainstate.random.randn(2)
|
120
|
+
output = LogSigmoid_layer(input)
|
121
|
+
|
122
|
+
def test_Softplus(self):
|
123
|
+
Softplus_layer = brainstate.nn.Softplus()
|
124
|
+
input = brainstate.random.randn(2)
|
125
|
+
output = Softplus_layer(input)
|
126
|
+
|
127
|
+
def test_Softshrink(self):
|
128
|
+
Softshrink_layer = brainstate.nn.Softshrink(lambd=1)
|
129
|
+
input = brainstate.random.randn(2)
|
130
|
+
output = Softshrink_layer(input)
|
131
|
+
|
132
|
+
def test_PReLU(self):
|
133
|
+
PReLU_layer = brainstate.nn.PReLU(num_parameters=2, init=0.5)
|
134
|
+
input = brainstate.random.randn(2)
|
135
|
+
output = PReLU_layer(input)
|
136
|
+
|
137
|
+
def test_Softsign(self):
|
138
|
+
Softsign_layer = brainstate.nn.Softsign()
|
139
|
+
input = brainstate.random.randn(2)
|
140
|
+
output = Softsign_layer(input)
|
141
|
+
|
142
|
+
def test_Tanhshrink(self):
|
143
|
+
Tanhshrink_layer = brainstate.nn.Tanhshrink()
|
144
|
+
input = brainstate.random.randn(2)
|
145
|
+
output = Tanhshrink_layer(input)
|
146
|
+
|
147
|
+
def test_Softmin(self):
|
148
|
+
Softmin_layer = brainstate.nn.Softmin(dim=2)
|
149
|
+
input = brainstate.random.randn(2, 3, 4)
|
150
|
+
output = Softmin_layer(input)
|
151
|
+
|
152
|
+
def test_Softmax(self):
|
153
|
+
Softmax_layer = brainstate.nn.Softmax(dim=2)
|
154
|
+
input = brainstate.random.randn(2, 3, 4)
|
155
|
+
output = Softmax_layer(input)
|
156
|
+
|
157
|
+
def test_Softmax2d(self):
|
158
|
+
Softmax2d_layer = brainstate.nn.Softmax2d()
|
159
|
+
input = brainstate.random.randn(2, 3, 12, 13)
|
160
|
+
output = Softmax2d_layer(input)
|
161
|
+
|
162
|
+
def test_LogSoftmax(self):
|
163
|
+
LogSoftmax_layer = brainstate.nn.LogSoftmax(dim=2)
|
164
|
+
input = brainstate.random.randn(2, 3, 4)
|
165
|
+
output = LogSoftmax_layer(input)
|
166
|
+
|
167
|
+
|
168
|
+
if __name__ == '__main__':
|
169
|
+
absltest.main()
|
@@ -17,8 +17,8 @@ from typing import Optional, Callable, Union
|
|
17
17
|
|
18
18
|
from brainstate import init
|
19
19
|
from brainstate._state import ParamState
|
20
|
-
from brainstate.nn._module import Module
|
21
20
|
from brainstate.typing import ArrayLike
|
21
|
+
from ._module import Module
|
22
22
|
|
23
23
|
__all__ = [
|
24
24
|
'Embedding',
|
brainstate/nn/_exp_euler_test.py
CHANGED
@@ -13,13 +13,12 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
15
|
|
16
|
-
from __future__ import annotations
|
17
16
|
|
18
17
|
import unittest
|
19
18
|
|
20
19
|
import brainunit as u
|
21
20
|
|
22
|
-
import brainstate
|
21
|
+
import brainstate
|
23
22
|
|
24
23
|
|
25
24
|
class TestExpEuler(unittest.TestCase):
|
@@ -27,10 +26,10 @@ class TestExpEuler(unittest.TestCase):
|
|
27
26
|
def fun(x, tau):
|
28
27
|
return -x / tau
|
29
28
|
|
30
|
-
with
|
29
|
+
with brainstate.environ.context(dt=0.1):
|
31
30
|
with self.assertRaises(AssertionError):
|
32
|
-
r =
|
31
|
+
r = brainstate.nn.exp_euler_step(fun, 1.0 * u.mV, 1. * u.ms)
|
33
32
|
|
34
|
-
with
|
35
|
-
r =
|
33
|
+
with brainstate.environ.context(dt=1. * u.ms):
|
34
|
+
r = brainstate.nn.exp_euler_step(fun, 1.0 * u.mV, 1. * u.ms)
|
36
35
|
print(r)
|
@@ -25,8 +25,8 @@ from brainstate import random, augment, environ, init
|
|
25
25
|
from brainstate._compatible_import import brainevent
|
26
26
|
from brainstate._state import ParamState
|
27
27
|
from brainstate.compile import for_loop
|
28
|
-
from brainstate.nn._module import Module
|
29
28
|
from brainstate.typing import Size, ArrayLike
|
29
|
+
from ._module import Module
|
30
30
|
|
31
31
|
__all__ = [
|
32
32
|
'EventFixedNumConn',
|
@@ -20,12 +20,11 @@ import jax
|
|
20
20
|
import numpy as np
|
21
21
|
|
22
22
|
from brainstate import environ, init, random
|
23
|
-
from brainstate._state import ShortTermState
|
24
|
-
from brainstate._state import State, maybe_state
|
23
|
+
from brainstate._state import ShortTermState, State, maybe_state
|
25
24
|
from brainstate.compile import while_loop
|
26
|
-
from brainstate.nn._dynamics._dynamics_base import Dynamics, Prefetch
|
27
|
-
from brainstate.nn._module import Module
|
28
25
|
from brainstate.typing import ArrayLike, Size, DTypeLike
|
26
|
+
from ._dynamics import Dynamics, Prefetch
|
27
|
+
from ._module import Module
|
29
28
|
|
30
29
|
__all__ = [
|
31
30
|
'SpikeTime',
|
@@ -134,7 +133,7 @@ class PoissonSpike(Dynamics):
|
|
134
133
|
self.freqs = init.param(freqs, self.varshape, allow_none=False)
|
135
134
|
|
136
135
|
def update(self):
|
137
|
-
spikes = random.rand(self.varshape) <= (self.freqs * environ.get_dt())
|
136
|
+
spikes = random.rand(*self.varshape) <= (self.freqs * environ.get_dt())
|
138
137
|
spikes = u.math.asarray(spikes, dtype=self.spk_type)
|
139
138
|
return spikes
|
140
139
|
|
@@ -22,8 +22,8 @@ import jax.numpy as jnp
|
|
22
22
|
|
23
23
|
from brainstate import init, functional
|
24
24
|
from brainstate._state import ParamState
|
25
|
-
from brainstate.nn._module import Module
|
26
25
|
from brainstate.typing import ArrayLike, Size
|
26
|
+
from ._module import Module
|
27
27
|
|
28
28
|
__all__ = [
|
29
29
|
'Linear',
|
@@ -350,10 +350,7 @@ class OneToOne(Module):
|
|
350
350
|
self.weight = param_type(param)
|
351
351
|
|
352
352
|
def update(self, pre_val):
|
353
|
-
|
354
|
-
w_val, w_unit = u.get_mantissa(self.weight.value['weight']), u.get_unit(self.weight.value['weight'])
|
355
|
-
post_val = pre_val * w_val
|
356
|
-
post_val = u.maybe_decimal(u.Quantity(post_val, unit=w_unit * pre_unit))
|
353
|
+
post_val = pre_val * self.weight.value['weight']
|
357
354
|
if 'bias' in self.weight.value:
|
358
355
|
post_val = post_val + self.weight.value['bias']
|
359
356
|
return post_val
|
@@ -21,8 +21,8 @@ import jax
|
|
21
21
|
from brainstate import init
|
22
22
|
from brainstate._compatible_import import brainevent
|
23
23
|
from brainstate._state import ParamState
|
24
|
-
from brainstate.nn._module import Module
|
25
24
|
from brainstate.typing import Size, ArrayLike
|
25
|
+
from ._module import Module
|
26
26
|
|
27
27
|
__all__ = [
|
28
28
|
'EventLinear',
|