brainstate 0.2.0__py2.py3-none-any.whl → 0.2.2__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +2 -4
- brainstate/_deprecation_test.py +2 -24
- brainstate/_state.py +540 -35
- brainstate/_state_test.py +1085 -8
- brainstate/graph/_operation.py +1 -5
- brainstate/mixin.py +14 -0
- brainstate/nn/__init__.py +42 -33
- brainstate/nn/_collective_ops.py +2 -0
- brainstate/nn/_common_test.py +0 -20
- brainstate/nn/_delay.py +1 -1
- brainstate/nn/_dropout_test.py +9 -6
- brainstate/nn/_dynamics.py +67 -464
- brainstate/nn/_dynamics_test.py +0 -14
- brainstate/nn/_embedding.py +7 -7
- brainstate/nn/_exp_euler.py +9 -9
- brainstate/nn/_linear.py +21 -21
- brainstate/nn/_module.py +25 -18
- brainstate/nn/_normalizations.py +27 -27
- brainstate/random/__init__.py +6 -6
- brainstate/random/{_rand_funs.py → _fun.py} +1 -1
- brainstate/random/{_rand_funs_test.py → _fun_test.py} +0 -2
- brainstate/random/_impl.py +672 -0
- brainstate/random/{_rand_seed.py → _seed.py} +1 -1
- brainstate/random/{_rand_state.py → _state.py} +121 -418
- brainstate/random/{_rand_state_test.py → _state_test.py} +7 -7
- brainstate/transform/__init__.py +6 -9
- brainstate/transform/_conditions.py +2 -2
- brainstate/transform/_find_state.py +200 -0
- brainstate/transform/_find_state_test.py +84 -0
- brainstate/transform/_make_jaxpr.py +221 -61
- brainstate/transform/_make_jaxpr_test.py +125 -1
- brainstate/transform/_mapping.py +287 -209
- brainstate/transform/_mapping_test.py +94 -184
- {brainstate-0.2.0.dist-info → brainstate-0.2.2.dist-info}/METADATA +1 -1
- {brainstate-0.2.0.dist-info → brainstate-0.2.2.dist-info}/RECORD +39 -39
- brainstate/transform/_eval_shape.py +0 -145
- brainstate/transform/_eval_shape_test.py +0 -38
- brainstate/transform/_random.py +0 -171
- /brainstate/random/{_rand_seed_test.py → _seed_test.py} +0 -0
- {brainstate-0.2.0.dist-info → brainstate-0.2.2.dist-info}/WHEEL +0 -0
- {brainstate-0.2.0.dist-info → brainstate-0.2.2.dist-info}/licenses/LICENSE +0 -0
- {brainstate-0.2.0.dist-info → brainstate-0.2.2.dist-info}/top_level.txt +0 -0
brainstate/graph/_operation.py
CHANGED
@@ -1262,11 +1262,7 @@ def treefy_states(
|
|
1262
1262
|
if len(filters) == 0:
|
1263
1263
|
return state_mapping
|
1264
1264
|
else:
|
1265
|
-
|
1266
|
-
if len(filters) == 1:
|
1267
|
-
return state_mappings[0]
|
1268
|
-
else:
|
1269
|
-
return state_mappings
|
1265
|
+
return state_mapping.filter(*filters)
|
1270
1266
|
|
1271
1267
|
|
1272
1268
|
def _graph_update_dynamic(node: Any, state: Mapping) -> None:
|
brainstate/mixin.py
CHANGED
@@ -380,6 +380,20 @@ class ParamDescriber(metaclass=NoSubclassMeta):
|
|
380
380
|
merged_kwargs = {**self.kwargs, **kwargs}
|
381
381
|
return self.cls(*self.args, *args, **merged_kwargs)
|
382
382
|
|
383
|
+
def __repr__(self):
|
384
|
+
"""
|
385
|
+
Return a string representation of the ParamDescriber.
|
386
|
+
|
387
|
+
Returns
|
388
|
+
-------
|
389
|
+
str
|
390
|
+
A string showing the class and stored parameters.
|
391
|
+
"""
|
392
|
+
args_str = ', '.join(repr(a) for a in self.args)
|
393
|
+
kwargs_str = ', '.join(f'{k}={v!r}' for k, v in self.kwargs.items())
|
394
|
+
all_params = ', '.join(filter(None, [args_str, kwargs_str]))
|
395
|
+
return f'ParamDescriber({self.cls.__name__}({all_params}))'
|
396
|
+
|
383
397
|
def init(self, *args, **kwargs):
|
384
398
|
"""
|
385
399
|
Alias for __call__, explicitly named for clarity.
|
brainstate/nn/__init__.py
CHANGED
@@ -87,43 +87,52 @@ del (
|
|
87
87
|
|
88
88
|
# Deprecated names that redirect to brainpy
|
89
89
|
_DEPRECATED_NAMES = {
|
90
|
-
'SpikeTime': 'brainpy.SpikeTime',
|
91
|
-
'PoissonSpike': 'brainpy.PoissonSpike',
|
92
|
-
'PoissonEncoder': 'brainpy.PoissonEncoder',
|
93
|
-
'PoissonInput': 'brainpy.PoissonInput',
|
94
|
-
'poisson_input': 'brainpy.poisson_input',
|
95
|
-
'Neuron': 'brainpy.Neuron',
|
96
|
-
'IF': 'brainpy.IF',
|
97
|
-
'LIF': 'brainpy.LIF',
|
98
|
-
'LIFRef': 'brainpy.LIFRef',
|
99
|
-
'ALIF': 'brainpy.ALIF',
|
100
|
-
'LeakyRateReadout': 'brainpy.LeakyRateReadout',
|
101
|
-
'LeakySpikeReadout': 'brainpy.LeakySpikeReadout',
|
102
|
-
'STP': 'brainpy.STP',
|
103
|
-
'STD': 'brainpy.STD',
|
104
|
-
'Synapse': 'brainpy.Synapse',
|
105
|
-
'Expon': 'brainpy.Expon',
|
106
|
-
'DualExpon': 'brainpy.DualExpon',
|
107
|
-
'Alpha': 'brainpy.Alpha',
|
108
|
-
'AMPA': 'brainpy.AMPA',
|
109
|
-
'GABAa': 'brainpy.GABAa',
|
110
|
-
'COBA': 'brainpy.COBA',
|
111
|
-
'CUBA': 'brainpy.CUBA',
|
112
|
-
'MgBlock': 'brainpy.MgBlock',
|
113
|
-
'SynOut': 'brainpy.SynOut',
|
114
|
-
'AlignPostProj': 'brainpy.AlignPostProj',
|
115
|
-
'DeltaProj': 'brainpy.DeltaProj',
|
116
|
-
'CurrentProj': 'brainpy.CurrentProj',
|
117
|
-
'align_pre_projection': 'brainpy.align_pre_projection',
|
118
|
-
'Projection': 'brainpy.Projection',
|
119
|
-
'SymmetryGapJunction': 'brainpy.SymmetryGapJunction',
|
120
|
-
'AsymmetryGapJunction': 'brainpy.AsymmetryGapJunction',
|
90
|
+
'SpikeTime': 'brainpy.state.SpikeTime',
|
91
|
+
'PoissonSpike': 'brainpy.state.PoissonSpike',
|
92
|
+
'PoissonEncoder': 'brainpy.state.PoissonEncoder',
|
93
|
+
'PoissonInput': 'brainpy.state.PoissonInput',
|
94
|
+
'poisson_input': 'brainpy.state.poisson_input',
|
95
|
+
'Neuron': 'brainpy.state.Neuron',
|
96
|
+
'IF': 'brainpy.state.IF',
|
97
|
+
'LIF': 'brainpy.state.LIF',
|
98
|
+
'LIFRef': 'brainpy.state.LIFRef',
|
99
|
+
'ALIF': 'brainpy.state.ALIF',
|
100
|
+
'LeakyRateReadout': 'brainpy.state.LeakyRateReadout',
|
101
|
+
'LeakySpikeReadout': 'brainpy.state.LeakySpikeReadout',
|
102
|
+
'STP': 'brainpy.state.STP',
|
103
|
+
'STD': 'brainpy.state.STD',
|
104
|
+
'Synapse': 'brainpy.state.Synapse',
|
105
|
+
'Expon': 'brainpy.state.Expon',
|
106
|
+
'DualExpon': 'brainpy.state.DualExpon',
|
107
|
+
'Alpha': 'brainpy.state.Alpha',
|
108
|
+
'AMPA': 'brainpy.state.AMPA',
|
109
|
+
'GABAa': 'brainpy.state.GABAa',
|
110
|
+
'COBA': 'brainpy.state.COBA',
|
111
|
+
'CUBA': 'brainpy.state.CUBA',
|
112
|
+
'MgBlock': 'brainpy.state.MgBlock',
|
113
|
+
'SynOut': 'brainpy.state.SynOut',
|
114
|
+
'AlignPostProj': 'brainpy.state.AlignPostProj',
|
115
|
+
'DeltaProj': 'brainpy.state.DeltaProj',
|
116
|
+
'CurrentProj': 'brainpy.state.CurrentProj',
|
117
|
+
'align_pre_projection': 'brainpy.state.align_pre_projection',
|
118
|
+
'Projection': 'brainpy.state.Projection',
|
119
|
+
'SymmetryGapJunction': 'brainpy.state.SymmetryGapJunction',
|
120
|
+
'AsymmetryGapJunction': 'brainpy.state.AsymmetryGapJunction',
|
121
121
|
}
|
122
122
|
|
123
123
|
|
124
124
|
def __getattr__(name: str):
|
125
|
+
import warnings
|
126
|
+
if name == 'DynamicsGroup':
|
127
|
+
warnings.warn(
|
128
|
+
f"'brainstate.nn.{name}' is deprecated and will be removed in a future version. "
|
129
|
+
f"Please use 'brainstate.nn.Module' instead.",
|
130
|
+
DeprecationWarning,
|
131
|
+
stacklevel=2
|
132
|
+
)
|
133
|
+
return Module
|
134
|
+
|
125
135
|
if name in _DEPRECATED_NAMES:
|
126
|
-
import warnings
|
127
136
|
new_name = _DEPRECATED_NAMES[name]
|
128
137
|
warnings.warn(
|
129
138
|
f"'brainstate.nn.{name}' is deprecated and will be removed in a future version. "
|
@@ -133,5 +142,5 @@ def __getattr__(name: str):
|
|
133
142
|
)
|
134
143
|
# Import and return the actual brainpy object
|
135
144
|
import brainpy
|
136
|
-
return getattr(brainpy, name)
|
145
|
+
return getattr(brainpy.state, name)
|
137
146
|
raise AttributeError(f"module 'brainstate.nn' has no attribute '{name}'")
|
brainstate/nn/_collective_ops.py
CHANGED
@@ -12,6 +12,8 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
|
+
|
16
|
+
|
15
17
|
import warnings
|
16
18
|
from collections.abc import Sequence, Mapping
|
17
19
|
from typing import Callable, TypeVar, Any
|
brainstate/nn/_common_test.py
CHANGED
@@ -132,23 +132,3 @@ class TestFilterStates(unittest.TestCase):
|
|
132
132
|
|
133
133
|
self.mock_module.states.assert_called_once_with(filter_obj)
|
134
134
|
self.assertEqual(result, ['test1', 'test2'])
|
135
|
-
|
136
|
-
def test_filter_states_dict_filters(self):
|
137
|
-
"""Test _filter_states with dictionary of filters.
|
138
|
-
|
139
|
-
Note: Current implementation expects dict to be iterable as tuples,
|
140
|
-
which suggests it's meant to be passed as a dict that yields tuples when iterated.
|
141
|
-
This is likely a bug - should use filters.items().
|
142
|
-
"""
|
143
|
-
# Skip this test as the current implementation has a bug
|
144
|
-
self.skipTest("Current implementation has a bug in dict iteration")
|
145
|
-
|
146
|
-
def test_filter_states_dict_invalid_axis(self):
|
147
|
-
"""Test _filter_states with non-integer axis in dictionary."""
|
148
|
-
# Skip this test as the current implementation has a bug in dict iteration
|
149
|
-
self.skipTest("Current implementation has a bug in dict iteration")
|
150
|
-
|
151
|
-
def test_filter_states_dict_multiple_filters_same_axis(self):
|
152
|
-
"""Test _filter_states with multiple filters for the same axis."""
|
153
|
-
# Skip this test as the current implementation has a bug in dict iteration
|
154
|
-
self.skipTest("Current implementation has a bug in dict iteration")
|
brainstate/nn/_delay.py
CHANGED
@@ -285,7 +285,7 @@ class Delay(Module):
|
|
285
285
|
Returns:
|
286
286
|
DelayAccess: An object that provides access to the delay data for the specified entry and time.
|
287
287
|
"""
|
288
|
-
return DelayAccess(self, delay_time, entry=entry)
|
288
|
+
return DelayAccess(self, *delay_time, entry=entry)
|
289
289
|
|
290
290
|
def at(self, entry: str) -> ArrayLike:
|
291
291
|
"""
|
brainstate/nn/_dropout_test.py
CHANGED
@@ -146,16 +146,19 @@ class TestDropout1d(parameterized.TestCase):
|
|
146
146
|
|
147
147
|
|
148
148
|
class TestDropout2d(parameterized.TestCase):
|
149
|
+
def setUp(self):
|
150
|
+
brainstate.random.seed(0)
|
149
151
|
|
150
152
|
def test_dropout2d_basic(self):
|
151
153
|
"""Test basic Dropout2d functionality."""
|
152
|
-
|
153
|
-
|
154
|
+
with brainstate.random.seed_context(42):
|
155
|
+
dropout_layer = brainstate.nn.Dropout2d(prob=0.5)
|
156
|
+
input_data = brainstate.random.randn(2, 3, 4, 5) # (N, C, H, W)
|
154
157
|
|
155
|
-
|
156
|
-
|
157
|
-
|
158
|
-
|
158
|
+
with brainstate.environ.context(fit=True):
|
159
|
+
output_data = dropout_layer(input_data)
|
160
|
+
self.assertEqual(input_data.shape, output_data.shape)
|
161
|
+
self.assertTrue(np.any(output_data == 0))
|
159
162
|
|
160
163
|
def test_dropout2d_channel_wise(self):
|
161
164
|
"""Test that Dropout2d applies dropout."""
|