brainstate 0.2.1__py2.py3-none-any.whl → 0.2.2__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +167 -169
- brainstate/_compatible_import.py +340 -340
- brainstate/_compatible_import_test.py +681 -681
- brainstate/_deprecation.py +210 -210
- brainstate/_deprecation_test.py +2297 -2319
- brainstate/_error.py +45 -45
- brainstate/_state.py +2157 -1652
- brainstate/_state_test.py +1129 -52
- brainstate/_utils.py +47 -47
- brainstate/environ.py +1495 -1495
- brainstate/environ_test.py +1223 -1223
- brainstate/graph/__init__.py +22 -22
- brainstate/graph/_node.py +240 -240
- brainstate/graph/_node_test.py +589 -589
- brainstate/graph/_operation.py +1620 -1624
- brainstate/graph/_operation_test.py +1147 -1147
- brainstate/mixin.py +1447 -1433
- brainstate/mixin_test.py +1017 -1017
- brainstate/nn/__init__.py +146 -137
- brainstate/nn/_activations.py +1100 -1100
- brainstate/nn/_activations_test.py +354 -354
- brainstate/nn/_collective_ops.py +635 -633
- brainstate/nn/_collective_ops_test.py +774 -774
- brainstate/nn/_common.py +226 -226
- brainstate/nn/_common_test.py +134 -154
- brainstate/nn/_conv.py +2010 -2010
- brainstate/nn/_conv_test.py +849 -849
- brainstate/nn/_delay.py +575 -575
- brainstate/nn/_delay_test.py +243 -243
- brainstate/nn/_dropout.py +618 -618
- brainstate/nn/_dropout_test.py +480 -477
- brainstate/nn/_dynamics.py +870 -1267
- brainstate/nn/_dynamics_test.py +53 -67
- brainstate/nn/_elementwise.py +1298 -1298
- brainstate/nn/_elementwise_test.py +829 -829
- brainstate/nn/_embedding.py +408 -408
- brainstate/nn/_embedding_test.py +156 -156
- brainstate/nn/_event_fixedprob.py +233 -233
- brainstate/nn/_event_fixedprob_test.py +115 -115
- brainstate/nn/_event_linear.py +83 -83
- brainstate/nn/_event_linear_test.py +121 -121
- brainstate/nn/_exp_euler.py +254 -254
- brainstate/nn/_exp_euler_test.py +377 -377
- brainstate/nn/_linear.py +744 -744
- brainstate/nn/_linear_test.py +475 -475
- brainstate/nn/_metrics.py +1070 -1070
- brainstate/nn/_metrics_test.py +611 -611
- brainstate/nn/_module.py +391 -384
- brainstate/nn/_module_test.py +40 -40
- brainstate/nn/_normalizations.py +1334 -1334
- brainstate/nn/_normalizations_test.py +699 -699
- brainstate/nn/_paddings.py +1020 -1020
- brainstate/nn/_paddings_test.py +722 -722
- brainstate/nn/_poolings.py +2239 -2239
- brainstate/nn/_poolings_test.py +952 -952
- brainstate/nn/_rnns.py +946 -946
- brainstate/nn/_rnns_test.py +592 -592
- brainstate/nn/_utils.py +216 -216
- brainstate/nn/_utils_test.py +401 -401
- brainstate/nn/init.py +809 -809
- brainstate/nn/init_test.py +180 -180
- brainstate/random/__init__.py +270 -270
- brainstate/random/{_rand_funs.py → _fun.py} +3938 -3938
- brainstate/random/{_rand_funs_test.py → _fun_test.py} +638 -640
- brainstate/random/_impl.py +672 -0
- brainstate/random/{_rand_seed.py → _seed.py} +675 -675
- brainstate/random/{_rand_seed_test.py → _seed_test.py} +48 -48
- brainstate/random/{_rand_state.py → _state.py} +1320 -1617
- brainstate/random/{_rand_state_test.py → _state_test.py} +551 -551
- brainstate/transform/__init__.py +56 -59
- brainstate/transform/_ad_checkpoint.py +176 -176
- brainstate/transform/_ad_checkpoint_test.py +49 -49
- brainstate/transform/_autograd.py +1025 -1025
- brainstate/transform/_autograd_test.py +1289 -1289
- brainstate/transform/_conditions.py +316 -316
- brainstate/transform/_conditions_test.py +220 -220
- brainstate/transform/_error_if.py +94 -94
- brainstate/transform/_error_if_test.py +52 -52
- brainstate/transform/_find_state.py +200 -0
- brainstate/transform/_find_state_test.py +84 -0
- brainstate/transform/_jit.py +399 -399
- brainstate/transform/_jit_test.py +143 -143
- brainstate/transform/_loop_collect_return.py +675 -675
- brainstate/transform/_loop_collect_return_test.py +58 -58
- brainstate/transform/_loop_no_collection.py +283 -283
- brainstate/transform/_loop_no_collection_test.py +50 -50
- brainstate/transform/_make_jaxpr.py +2176 -2016
- brainstate/transform/_make_jaxpr_test.py +1634 -1510
- brainstate/transform/_mapping.py +607 -529
- brainstate/transform/_mapping_test.py +104 -194
- brainstate/transform/_progress_bar.py +255 -255
- brainstate/transform/_unvmap.py +256 -256
- brainstate/transform/_util.py +286 -286
- brainstate/typing.py +837 -837
- brainstate/typing_test.py +780 -780
- brainstate/util/__init__.py +27 -27
- brainstate/util/_others.py +1024 -1024
- brainstate/util/_others_test.py +962 -962
- brainstate/util/_pretty_pytree.py +1301 -1301
- brainstate/util/_pretty_pytree_test.py +675 -675
- brainstate/util/_pretty_repr.py +462 -462
- brainstate/util/_pretty_repr_test.py +696 -696
- brainstate/util/filter.py +945 -945
- brainstate/util/filter_test.py +911 -911
- brainstate/util/struct.py +910 -910
- brainstate/util/struct_test.py +602 -602
- {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/METADATA +108 -108
- brainstate-0.2.2.dist-info/RECORD +111 -0
- {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/licenses/LICENSE +202 -202
- brainstate/transform/_eval_shape.py +0 -145
- brainstate/transform/_eval_shape_test.py +0 -38
- brainstate/transform/_random.py +0 -171
- brainstate-0.2.1.dist-info/RECORD +0 -111
- {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/WHEEL +0 -0
- {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/top_level.txt +0 -0
brainstate/nn/_common_test.py
CHANGED
@@ -1,154 +1,134 @@
|
|
1
|
-
# Copyright 2025 BrainX Ecosystem Limited. All Rights Reserved.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
15
|
-
|
16
|
-
import unittest
|
17
|
-
from unittest.mock import Mock, patch
|
18
|
-
|
19
|
-
import jax.numpy as jnp
|
20
|
-
|
21
|
-
import brainstate
|
22
|
-
from brainstate import environ
|
23
|
-
from brainstate.nn import Module, EnvironContext
|
24
|
-
from brainstate.nn._common import _filter_states
|
25
|
-
|
26
|
-
|
27
|
-
class DummyModule(Module):
|
28
|
-
"""A simple module for testing purposes."""
|
29
|
-
|
30
|
-
def __init__(self, value=0):
|
31
|
-
super().__init__()
|
32
|
-
self.value = value
|
33
|
-
self.state = brainstate.State(jnp.array([1.0, 2.0, 3.0]))
|
34
|
-
self.param = brainstate.ParamState(jnp.array([4.0, 5.0, 6.0]))
|
35
|
-
|
36
|
-
def update(self, x):
|
37
|
-
return x + self.value
|
38
|
-
|
39
|
-
def __call__(self, x, y=0):
|
40
|
-
return x + self.value + y
|
41
|
-
|
42
|
-
|
43
|
-
class TestEnvironContext(unittest.TestCase):
|
44
|
-
"""Test cases for EnvironContext class."""
|
45
|
-
|
46
|
-
def setUp(self):
|
47
|
-
"""Set up test fixtures."""
|
48
|
-
self.dummy_module = DummyModule(10)
|
49
|
-
|
50
|
-
def test_init_valid_module(self):
|
51
|
-
"""Test EnvironContext initialization with valid module."""
|
52
|
-
context = EnvironContext(self.dummy_module, fit=True, a='test')
|
53
|
-
self.assertEqual(context.layer, self.dummy_module)
|
54
|
-
self.assertEqual(context.context, {'fit': True, 'a': 'test'})
|
55
|
-
|
56
|
-
def test_init_invalid_module(self):
|
57
|
-
"""Test EnvironContext initialization with invalid module."""
|
58
|
-
with self.assertRaises(AssertionError):
|
59
|
-
EnvironContext("not a module", training=True)
|
60
|
-
|
61
|
-
with self.assertRaises(AssertionError):
|
62
|
-
EnvironContext(None, training=True)
|
63
|
-
|
64
|
-
with self.assertRaises(AssertionError):
|
65
|
-
EnvironContext(42, training=True)
|
66
|
-
|
67
|
-
def test_update_with_context(self):
|
68
|
-
"""Test update method applies context correctly."""
|
69
|
-
context = EnvironContext(self.dummy_module, fit=True)
|
70
|
-
|
71
|
-
# Test with positional arguments
|
72
|
-
result = context.update(5)
|
73
|
-
self.assertEqual(result, 15) # 5 + 10
|
74
|
-
|
75
|
-
# Test with keyword arguments
|
76
|
-
result = context.update(5, y=3)
|
77
|
-
self.assertEqual(result, 18) # 5 + 10 + 3
|
78
|
-
|
79
|
-
def test_update_context_applied(self):
|
80
|
-
"""Test that environment context is actually applied during update."""
|
81
|
-
with patch.object(environ, 'context') as mock_context:
|
82
|
-
mock_context.return_value.__enter__ = Mock(return_value=None)
|
83
|
-
mock_context.return_value.__exit__ = Mock(return_value=None)
|
84
|
-
|
85
|
-
context = EnvironContext(self.dummy_module, fit=True, a='eval')
|
86
|
-
context.update(5)
|
87
|
-
|
88
|
-
mock_context.assert_called_once_with(fit=True, a='eval')
|
89
|
-
|
90
|
-
def test_add_context(self):
|
91
|
-
"""Test add_context method updates context correctly."""
|
92
|
-
context = EnvironContext(self.dummy_module, fit=True)
|
93
|
-
self.assertEqual(context.context, {'fit': True})
|
94
|
-
|
95
|
-
# Add new context
|
96
|
-
context.add_context(a='test', debug=False)
|
97
|
-
self.assertEqual(context.context, {'fit': True, 'a': 'test', 'debug': False})
|
98
|
-
|
99
|
-
# Overwrite existing context
|
100
|
-
context.add_context(fit=False)
|
101
|
-
self.assertEqual(context.context, {'fit': False, 'a': 'test', 'debug': False})
|
102
|
-
|
103
|
-
def test_empty_context(self):
|
104
|
-
"""Test EnvironContext with no initial context."""
|
105
|
-
context = EnvironContext(self.dummy_module)
|
106
|
-
self.assertEqual(context.context, {})
|
107
|
-
|
108
|
-
result = context.update(7)
|
109
|
-
self.assertEqual(result, 17) # 7 + 10
|
110
|
-
|
111
|
-
|
112
|
-
class TestFilterStates(unittest.TestCase):
|
113
|
-
"""Test cases for _filter_states function."""
|
114
|
-
|
115
|
-
def setUp(self):
|
116
|
-
"""Set up test fixtures."""
|
117
|
-
self.mock_module = Mock(spec=Module)
|
118
|
-
self.mock_module.states = Mock()
|
119
|
-
|
120
|
-
def test_filter_states_none(self):
|
121
|
-
"""Test _filter_states with None filters."""
|
122
|
-
result = _filter_states(self.mock_module, None)
|
123
|
-
self.assertIsNone(result)
|
124
|
-
self.mock_module.states.assert_not_called()
|
125
|
-
|
126
|
-
def test_filter_states_single_filter(self):
|
127
|
-
"""Test _filter_states with single filter (non-dict)."""
|
128
|
-
filter_obj = lambda x: x.startswith('test')
|
129
|
-
self.mock_module.states.return_value = ['test1', 'test2']
|
130
|
-
|
131
|
-
result = _filter_states(self.mock_module, filter_obj)
|
132
|
-
|
133
|
-
self.mock_module.states.assert_called_once_with(filter_obj)
|
134
|
-
self.assertEqual(result, ['test1', 'test2'])
|
135
|
-
|
136
|
-
def test_filter_states_dict_filters(self):
|
137
|
-
"""Test _filter_states with dictionary of filters.
|
138
|
-
|
139
|
-
Note: Current implementation expects dict to be iterable as tuples,
|
140
|
-
which suggests it's meant to be passed as a dict that yields tuples when iterated.
|
141
|
-
This is likely a bug - should use filters.items().
|
142
|
-
"""
|
143
|
-
# Skip this test as the current implementation has a bug
|
144
|
-
self.skipTest("Current implementation has a bug in dict iteration")
|
145
|
-
|
146
|
-
def test_filter_states_dict_invalid_axis(self):
|
147
|
-
"""Test _filter_states with non-integer axis in dictionary."""
|
148
|
-
# Skip this test as the current implementation has a bug in dict iteration
|
149
|
-
self.skipTest("Current implementation has a bug in dict iteration")
|
150
|
-
|
151
|
-
def test_filter_states_dict_multiple_filters_same_axis(self):
|
152
|
-
"""Test _filter_states with multiple filters for the same axis."""
|
153
|
-
# Skip this test as the current implementation has a bug in dict iteration
|
154
|
-
self.skipTest("Current implementation has a bug in dict iteration")
|
1
|
+
# Copyright 2025 BrainX Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
import unittest
|
17
|
+
from unittest.mock import Mock, patch
|
18
|
+
|
19
|
+
import jax.numpy as jnp
|
20
|
+
|
21
|
+
import brainstate
|
22
|
+
from brainstate import environ
|
23
|
+
from brainstate.nn import Module, EnvironContext
|
24
|
+
from brainstate.nn._common import _filter_states
|
25
|
+
|
26
|
+
|
27
|
+
class DummyModule(Module):
|
28
|
+
"""A simple module for testing purposes."""
|
29
|
+
|
30
|
+
def __init__(self, value=0):
|
31
|
+
super().__init__()
|
32
|
+
self.value = value
|
33
|
+
self.state = brainstate.State(jnp.array([1.0, 2.0, 3.0]))
|
34
|
+
self.param = brainstate.ParamState(jnp.array([4.0, 5.0, 6.0]))
|
35
|
+
|
36
|
+
def update(self, x):
|
37
|
+
return x + self.value
|
38
|
+
|
39
|
+
def __call__(self, x, y=0):
|
40
|
+
return x + self.value + y
|
41
|
+
|
42
|
+
|
43
|
+
class TestEnvironContext(unittest.TestCase):
|
44
|
+
"""Test cases for EnvironContext class."""
|
45
|
+
|
46
|
+
def setUp(self):
|
47
|
+
"""Set up test fixtures."""
|
48
|
+
self.dummy_module = DummyModule(10)
|
49
|
+
|
50
|
+
def test_init_valid_module(self):
|
51
|
+
"""Test EnvironContext initialization with valid module."""
|
52
|
+
context = EnvironContext(self.dummy_module, fit=True, a='test')
|
53
|
+
self.assertEqual(context.layer, self.dummy_module)
|
54
|
+
self.assertEqual(context.context, {'fit': True, 'a': 'test'})
|
55
|
+
|
56
|
+
def test_init_invalid_module(self):
|
57
|
+
"""Test EnvironContext initialization with invalid module."""
|
58
|
+
with self.assertRaises(AssertionError):
|
59
|
+
EnvironContext("not a module", training=True)
|
60
|
+
|
61
|
+
with self.assertRaises(AssertionError):
|
62
|
+
EnvironContext(None, training=True)
|
63
|
+
|
64
|
+
with self.assertRaises(AssertionError):
|
65
|
+
EnvironContext(42, training=True)
|
66
|
+
|
67
|
+
def test_update_with_context(self):
|
68
|
+
"""Test update method applies context correctly."""
|
69
|
+
context = EnvironContext(self.dummy_module, fit=True)
|
70
|
+
|
71
|
+
# Test with positional arguments
|
72
|
+
result = context.update(5)
|
73
|
+
self.assertEqual(result, 15) # 5 + 10
|
74
|
+
|
75
|
+
# Test with keyword arguments
|
76
|
+
result = context.update(5, y=3)
|
77
|
+
self.assertEqual(result, 18) # 5 + 10 + 3
|
78
|
+
|
79
|
+
def test_update_context_applied(self):
|
80
|
+
"""Test that environment context is actually applied during update."""
|
81
|
+
with patch.object(environ, 'context') as mock_context:
|
82
|
+
mock_context.return_value.__enter__ = Mock(return_value=None)
|
83
|
+
mock_context.return_value.__exit__ = Mock(return_value=None)
|
84
|
+
|
85
|
+
context = EnvironContext(self.dummy_module, fit=True, a='eval')
|
86
|
+
context.update(5)
|
87
|
+
|
88
|
+
mock_context.assert_called_once_with(fit=True, a='eval')
|
89
|
+
|
90
|
+
def test_add_context(self):
|
91
|
+
"""Test add_context method updates context correctly."""
|
92
|
+
context = EnvironContext(self.dummy_module, fit=True)
|
93
|
+
self.assertEqual(context.context, {'fit': True})
|
94
|
+
|
95
|
+
# Add new context
|
96
|
+
context.add_context(a='test', debug=False)
|
97
|
+
self.assertEqual(context.context, {'fit': True, 'a': 'test', 'debug': False})
|
98
|
+
|
99
|
+
# Overwrite existing context
|
100
|
+
context.add_context(fit=False)
|
101
|
+
self.assertEqual(context.context, {'fit': False, 'a': 'test', 'debug': False})
|
102
|
+
|
103
|
+
def test_empty_context(self):
|
104
|
+
"""Test EnvironContext with no initial context."""
|
105
|
+
context = EnvironContext(self.dummy_module)
|
106
|
+
self.assertEqual(context.context, {})
|
107
|
+
|
108
|
+
result = context.update(7)
|
109
|
+
self.assertEqual(result, 17) # 7 + 10
|
110
|
+
|
111
|
+
|
112
|
+
class TestFilterStates(unittest.TestCase):
|
113
|
+
"""Test cases for _filter_states function."""
|
114
|
+
|
115
|
+
def setUp(self):
|
116
|
+
"""Set up test fixtures."""
|
117
|
+
self.mock_module = Mock(spec=Module)
|
118
|
+
self.mock_module.states = Mock()
|
119
|
+
|
120
|
+
def test_filter_states_none(self):
|
121
|
+
"""Test _filter_states with None filters."""
|
122
|
+
result = _filter_states(self.mock_module, None)
|
123
|
+
self.assertIsNone(result)
|
124
|
+
self.mock_module.states.assert_not_called()
|
125
|
+
|
126
|
+
def test_filter_states_single_filter(self):
|
127
|
+
"""Test _filter_states with single filter (non-dict)."""
|
128
|
+
filter_obj = lambda x: x.startswith('test')
|
129
|
+
self.mock_module.states.return_value = ['test1', 'test2']
|
130
|
+
|
131
|
+
result = _filter_states(self.mock_module, filter_obj)
|
132
|
+
|
133
|
+
self.mock_module.states.assert_called_once_with(filter_obj)
|
134
|
+
self.assertEqual(result, ['test1', 'test2'])
|