brainstate 0.1.9__py2.py3-none-any.whl → 0.2.0__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +130 -19
- brainstate/_compatible_import.py +201 -9
- brainstate/_compatible_import_test.py +681 -0
- brainstate/_deprecation.py +210 -0
- brainstate/_deprecation_test.py +2319 -0
- brainstate/{util/error.py → _error.py} +10 -20
- brainstate/_state.py +94 -47
- brainstate/_state_test.py +1 -1
- brainstate/_utils.py +1 -1
- brainstate/environ.py +1279 -347
- brainstate/environ_test.py +1187 -26
- brainstate/graph/__init__.py +6 -13
- brainstate/graph/_node.py +240 -0
- brainstate/graph/_node_test.py +589 -0
- brainstate/graph/{_graph_operation.py → _operation.py} +632 -746
- brainstate/graph/_operation_test.py +1147 -0
- brainstate/mixin.py +1209 -141
- brainstate/mixin_test.py +991 -51
- brainstate/nn/__init__.py +74 -72
- brainstate/nn/_activations.py +587 -295
- brainstate/nn/_activations_test.py +109 -86
- brainstate/nn/_collective_ops.py +393 -274
- brainstate/nn/_collective_ops_test.py +746 -15
- brainstate/nn/_common.py +114 -66
- brainstate/nn/_common_test.py +154 -0
- brainstate/nn/_conv.py +1652 -143
- brainstate/nn/_conv_test.py +838 -227
- brainstate/nn/_delay.py +95 -29
- brainstate/nn/_delay_test.py +25 -20
- brainstate/nn/_dropout.py +359 -167
- brainstate/nn/_dropout_test.py +429 -52
- brainstate/nn/_dynamics.py +14 -90
- brainstate/nn/_dynamics_test.py +1 -12
- brainstate/nn/_elementwise.py +492 -313
- brainstate/nn/_elementwise_test.py +806 -145
- brainstate/nn/_embedding.py +369 -19
- brainstate/nn/_embedding_test.py +156 -0
- brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +10 -16
- brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +6 -5
- brainstate/nn/{_linear_mv.py → _event_linear.py} +2 -2
- brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +6 -5
- brainstate/nn/_exp_euler.py +200 -38
- brainstate/nn/_exp_euler_test.py +350 -8
- brainstate/nn/_linear.py +391 -71
- brainstate/nn/_linear_test.py +427 -59
- brainstate/nn/_metrics.py +1070 -0
- brainstate/nn/_metrics_test.py +611 -0
- brainstate/nn/_module.py +10 -3
- brainstate/nn/_module_test.py +1 -1
- brainstate/nn/_normalizations.py +688 -329
- brainstate/nn/_normalizations_test.py +663 -37
- brainstate/nn/_paddings.py +1020 -0
- brainstate/nn/_paddings_test.py +723 -0
- brainstate/nn/_poolings.py +1404 -342
- brainstate/nn/_poolings_test.py +828 -92
- brainstate/nn/{_rate_rnns.py → _rnns.py} +446 -54
- brainstate/nn/_rnns_test.py +593 -0
- brainstate/nn/_utils.py +132 -5
- brainstate/nn/_utils_test.py +402 -0
- brainstate/{init/_random_inits.py → nn/init.py} +301 -45
- brainstate/{init/_random_inits_test.py → nn/init_test.py} +51 -20
- brainstate/random/__init__.py +247 -1
- brainstate/random/_rand_funs.py +668 -346
- brainstate/random/_rand_funs_test.py +74 -1
- brainstate/random/_rand_seed.py +541 -76
- brainstate/random/_rand_seed_test.py +1 -1
- brainstate/random/_rand_state.py +601 -393
- brainstate/random/_rand_state_test.py +551 -0
- brainstate/transform/__init__.py +59 -0
- brainstate/transform/_ad_checkpoint.py +176 -0
- brainstate/{compile → transform}/_ad_checkpoint_test.py +1 -1
- brainstate/{augment → transform}/_autograd.py +360 -113
- brainstate/{augment → transform}/_autograd_test.py +2 -2
- brainstate/transform/_conditions.py +316 -0
- brainstate/{compile → transform}/_conditions_test.py +11 -11
- brainstate/{compile → transform}/_error_if.py +22 -20
- brainstate/{compile → transform}/_error_if_test.py +1 -1
- brainstate/transform/_eval_shape.py +145 -0
- brainstate/{augment → transform}/_eval_shape_test.py +1 -1
- brainstate/{compile → transform}/_jit.py +99 -46
- brainstate/{compile → transform}/_jit_test.py +3 -3
- brainstate/{compile → transform}/_loop_collect_return.py +219 -80
- brainstate/{compile → transform}/_loop_collect_return_test.py +1 -1
- brainstate/{compile → transform}/_loop_no_collection.py +133 -34
- brainstate/{compile → transform}/_loop_no_collection_test.py +2 -2
- brainstate/transform/_make_jaxpr.py +2016 -0
- brainstate/transform/_make_jaxpr_test.py +1510 -0
- brainstate/transform/_mapping.py +529 -0
- brainstate/transform/_mapping_test.py +194 -0
- brainstate/{compile → transform}/_progress_bar.py +78 -25
- brainstate/{augment → transform}/_random.py +65 -45
- brainstate/{compile → transform}/_unvmap.py +102 -5
- brainstate/transform/_util.py +286 -0
- brainstate/typing.py +594 -61
- brainstate/typing_test.py +780 -0
- brainstate/util/__init__.py +9 -32
- brainstate/util/_others.py +1025 -0
- brainstate/util/_others_test.py +962 -0
- brainstate/util/_pretty_pytree.py +1301 -0
- brainstate/util/_pretty_pytree_test.py +675 -0
- brainstate/util/{pretty_repr.py → _pretty_repr.py} +161 -27
- brainstate/util/_pretty_repr_test.py +696 -0
- brainstate/util/filter.py +557 -81
- brainstate/util/filter_test.py +912 -0
- brainstate/util/struct.py +769 -382
- brainstate/util/struct_test.py +602 -0
- {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/METADATA +34 -17
- brainstate-0.2.0.dist-info/RECORD +111 -0
- brainstate/augment/__init__.py +0 -30
- brainstate/augment/_eval_shape.py +0 -99
- brainstate/augment/_mapping.py +0 -1060
- brainstate/augment/_mapping_test.py +0 -597
- brainstate/compile/__init__.py +0 -38
- brainstate/compile/_ad_checkpoint.py +0 -204
- brainstate/compile/_conditions.py +0 -256
- brainstate/compile/_make_jaxpr.py +0 -888
- brainstate/compile/_make_jaxpr_test.py +0 -156
- brainstate/compile/_util.py +0 -147
- brainstate/functional/__init__.py +0 -27
- brainstate/graph/_graph_node.py +0 -244
- brainstate/graph/_graph_node_test.py +0 -73
- brainstate/graph/_graph_operation_test.py +0 -563
- brainstate/init/__init__.py +0 -26
- brainstate/init/_base.py +0 -52
- brainstate/init/_generic.py +0 -244
- brainstate/init/_regular_inits.py +0 -105
- brainstate/init/_regular_inits_test.py +0 -50
- brainstate/nn/_inputs.py +0 -608
- brainstate/nn/_ltp.py +0 -28
- brainstate/nn/_neuron.py +0 -705
- brainstate/nn/_neuron_test.py +0 -161
- brainstate/nn/_others.py +0 -46
- brainstate/nn/_projection.py +0 -486
- brainstate/nn/_rate_rnns_test.py +0 -63
- brainstate/nn/_readout.py +0 -209
- brainstate/nn/_readout_test.py +0 -53
- brainstate/nn/_stp.py +0 -236
- brainstate/nn/_synapse.py +0 -505
- brainstate/nn/_synapse_test.py +0 -131
- brainstate/nn/_synaptic_projection.py +0 -423
- brainstate/nn/_synouts.py +0 -162
- brainstate/nn/_synouts_test.py +0 -57
- brainstate/nn/metrics.py +0 -388
- brainstate/optim/__init__.py +0 -38
- brainstate/optim/_base.py +0 -64
- brainstate/optim/_lr_scheduler.py +0 -448
- brainstate/optim/_lr_scheduler_test.py +0 -50
- brainstate/optim/_optax_optimizer.py +0 -152
- brainstate/optim/_optax_optimizer_test.py +0 -53
- brainstate/optim/_sgd_optimizer.py +0 -1104
- brainstate/random/_random_for_unit.py +0 -52
- brainstate/surrogate.py +0 -1957
- brainstate/transform.py +0 -23
- brainstate/util/caller.py +0 -98
- brainstate/util/others.py +0 -540
- brainstate/util/pretty_pytree.py +0 -945
- brainstate/util/pretty_pytree_test.py +0 -159
- brainstate/util/pretty_table.py +0 -2954
- brainstate/util/scaling.py +0 -258
- brainstate-0.1.9.dist-info/RECORD +0 -130
- {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/WHEEL +0 -0
- {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/licenses/LICENSE +0 -0
- {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/top_level.txt +0 -0
@@ -1,4 +1,4 @@
|
|
1
|
-
# Copyright 2025
|
1
|
+
# Copyright 2025 BrainX Ecosystem Limited. All Rights Reserved.
|
2
2
|
#
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
4
|
# you may not use this file except in compliance with the License.
|
@@ -16,28 +16,759 @@
|
|
16
16
|
# -*- coding: utf-8 -*-
|
17
17
|
|
18
18
|
|
19
|
+
import jax.numpy as jnp
|
20
|
+
import pytest
|
21
|
+
|
19
22
|
import brainstate
|
20
23
|
|
21
24
|
|
22
|
-
class
|
25
|
+
class SimpleTestModule(brainstate.nn.Module):
|
26
|
+
"""Simple test module with init_state method"""
|
23
27
|
|
24
|
-
def
|
25
|
-
|
26
|
-
|
27
|
-
|
28
|
+
def __init__(self):
|
29
|
+
super().__init__()
|
30
|
+
self.state_initialized = False
|
31
|
+
self.init_args = None
|
32
|
+
self.init_kwargs = None
|
33
|
+
|
34
|
+
def init_state(self, *args, **kwargs):
|
35
|
+
self.state_initialized = True
|
36
|
+
self.init_args = args
|
37
|
+
self.init_kwargs = kwargs
|
38
|
+
self.state = brainstate.State(jnp.zeros(5))
|
39
|
+
|
40
|
+
|
41
|
+
class OrderedTestModule(brainstate.nn.Module):
|
42
|
+
"""Test module with call_order decorator"""
|
43
|
+
|
44
|
+
def __init__(self, order_level):
|
45
|
+
super().__init__()
|
46
|
+
self.order_level = order_level
|
47
|
+
self.call_sequence = []
|
48
|
+
|
49
|
+
@brainstate.nn.call_order(0)
|
50
|
+
def init_state(self):
|
51
|
+
self.state = brainstate.State(jnp.array([self.order_level]))
|
28
52
|
|
29
|
-
def test_vmap_init_all_states_v2(self):
|
30
|
-
@brainstate.compile.jit
|
31
|
-
def init():
|
32
|
-
gru = brainstate.nn.GRUCell(1, 2)
|
33
|
-
brainstate.nn.vmap_init_all_states(gru, axis_size=10)
|
34
|
-
print(gru)
|
35
53
|
|
36
|
-
|
54
|
+
class NestedModule(brainstate.nn.Module):
|
55
|
+
"""Module with nested submodules"""
|
56
|
+
|
57
|
+
def __init__(self):
|
58
|
+
super().__init__()
|
59
|
+
self.submodule1 = SimpleTestModule()
|
60
|
+
self.submodule2 = SimpleTestModule()
|
61
|
+
|
62
|
+
def init_state(self):
|
63
|
+
self.state = brainstate.State(jnp.zeros(3))
|
37
64
|
|
38
65
|
|
39
66
|
class Test_init_all_states:
|
40
|
-
|
67
|
+
"""Comprehensive tests for init_all_states function"""
|
68
|
+
|
69
|
+
def test_basic_initialization(self):
|
70
|
+
"""Test basic state initialization"""
|
71
|
+
module = SimpleTestModule()
|
72
|
+
assert not module.state_initialized
|
73
|
+
|
74
|
+
brainstate.nn.init_all_states(module)
|
75
|
+
|
76
|
+
assert module.state_initialized
|
77
|
+
assert module.state.value.shape == (5,)
|
78
|
+
|
79
|
+
def test_with_positional_args(self):
|
80
|
+
"""Test init_all_states with positional arguments"""
|
81
|
+
module = SimpleTestModule()
|
82
|
+
|
83
|
+
brainstate.nn.init_all_states(module, 1, 2, 3)
|
84
|
+
|
85
|
+
assert module.state_initialized
|
86
|
+
assert module.init_args == (1, 2, 3)
|
87
|
+
|
88
|
+
def test_with_keyword_args(self):
|
89
|
+
"""Test init_all_states with keyword arguments"""
|
90
|
+
module = SimpleTestModule()
|
91
|
+
|
92
|
+
brainstate.nn.init_all_states(module, batch_size=10, seq_len=20)
|
93
|
+
|
94
|
+
assert module.state_initialized
|
95
|
+
assert module.init_kwargs == {'batch_size': 10, 'seq_len': 20}
|
96
|
+
|
97
|
+
def test_with_mixed_args(self):
|
98
|
+
"""Test init_all_states with both positional and keyword arguments"""
|
99
|
+
module = SimpleTestModule()
|
100
|
+
|
101
|
+
brainstate.nn.init_all_states(module, 42, batch_size=10)
|
102
|
+
|
103
|
+
assert module.state_initialized
|
104
|
+
assert module.init_args == (42,)
|
105
|
+
assert module.init_kwargs == {'batch_size': 10}
|
106
|
+
|
107
|
+
def test_nested_modules(self):
|
108
|
+
"""Test that init_all_states initializes nested submodules"""
|
109
|
+
module = NestedModule()
|
110
|
+
|
111
|
+
brainstate.nn.init_all_states(module)
|
112
|
+
|
113
|
+
assert module.submodule1.state_initialized
|
114
|
+
assert module.submodule2.state_initialized
|
115
|
+
assert hasattr(module, 'state')
|
116
|
+
|
117
|
+
def test_with_gru_cell(self):
|
118
|
+
"""Test with real GRUCell module"""
|
41
119
|
gru = brainstate.nn.GRUCell(1, 2)
|
120
|
+
|
42
121
|
brainstate.nn.init_all_states(gru, batch_size=10)
|
43
|
-
|
122
|
+
|
123
|
+
# Check that states were created
|
124
|
+
state_dict = gru.states()
|
125
|
+
assert len(state_dict) > 0
|
126
|
+
|
127
|
+
def test_sequential_module(self):
|
128
|
+
"""Test with Sequential module containing multiple layers"""
|
129
|
+
seq = brainstate.nn.Sequential(
|
130
|
+
brainstate.nn.Linear(10, 20),
|
131
|
+
brainstate.nn.Dropout(0.5)
|
132
|
+
)
|
133
|
+
|
134
|
+
brainstate.nn.init_all_states(seq)
|
135
|
+
|
136
|
+
# Check that Linear layer has weight and bias
|
137
|
+
state_dict = seq.states()
|
138
|
+
assert len(state_dict) > 0
|
139
|
+
|
140
|
+
def test_node_to_exclude(self):
|
141
|
+
"""Test excluding specific nodes from initialization"""
|
142
|
+
module = NestedModule()
|
143
|
+
|
144
|
+
# Exclude submodule1 by type matching - simpler and more reliable
|
145
|
+
brainstate.nn.init_all_states(
|
146
|
+
module,
|
147
|
+
node_to_exclude=NestedModule # Exclude the parent, only init children
|
148
|
+
)
|
149
|
+
|
150
|
+
# Parent should not be initialized, but children should be
|
151
|
+
assert not hasattr(module, 'state') or module.state is None or not hasattr(module.state, 'value')
|
152
|
+
assert module.submodule1.state_initialized
|
153
|
+
assert module.submodule2.state_initialized
|
154
|
+
|
155
|
+
def test_with_call_order(self):
|
156
|
+
"""Test that call_order is respected during initialization"""
|
157
|
+
|
158
|
+
class OrderedModule(brainstate.nn.Module):
|
159
|
+
def __init__(self):
|
160
|
+
super().__init__()
|
161
|
+
self.execution_order = []
|
162
|
+
|
163
|
+
@brainstate.nn.call_order(1)
|
164
|
+
def init_state(self):
|
165
|
+
self.execution_order.append('parent')
|
166
|
+
|
167
|
+
class ChildModule(brainstate.nn.Module):
|
168
|
+
def __init__(self, parent_module):
|
169
|
+
super().__init__()
|
170
|
+
self.parent = parent_module
|
171
|
+
|
172
|
+
@brainstate.nn.call_order(0)
|
173
|
+
def init_state(self):
|
174
|
+
self.parent.execution_order.append('child')
|
175
|
+
|
176
|
+
parent = OrderedModule()
|
177
|
+
child = ChildModule(parent)
|
178
|
+
parent.child_module = child
|
179
|
+
|
180
|
+
brainstate.nn.init_all_states(parent)
|
181
|
+
|
182
|
+
# Child (order 0) should execute before parent (order 1)
|
183
|
+
assert parent.execution_order == ['child', 'parent']
|
184
|
+
|
185
|
+
|
186
|
+
class ResetTestModule(brainstate.nn.Module):
|
187
|
+
"""Test module with both init_state and reset_state methods"""
|
188
|
+
|
189
|
+
def __init__(self):
|
190
|
+
super().__init__()
|
191
|
+
self.state_initialized = False
|
192
|
+
self.state_reset = False
|
193
|
+
self.reset_args = None
|
194
|
+
self.reset_kwargs = None
|
195
|
+
|
196
|
+
def init_state(self, *args, **kwargs):
|
197
|
+
self.state_initialized = True
|
198
|
+
self.state_reset = False
|
199
|
+
self.state = brainstate.State(jnp.ones(5))
|
200
|
+
|
201
|
+
def reset_state(self, *args, **kwargs):
|
202
|
+
self.state_reset = True
|
203
|
+
self.reset_args = args
|
204
|
+
self.reset_kwargs = kwargs
|
205
|
+
if hasattr(self, 'state'):
|
206
|
+
self.state.value = jnp.zeros(5)
|
207
|
+
|
208
|
+
|
209
|
+
class ResetOrderedModule(brainstate.nn.Module):
|
210
|
+
"""Test module with call_order on reset_state"""
|
211
|
+
|
212
|
+
def __init__(self, order_level, execution_tracker):
|
213
|
+
super().__init__()
|
214
|
+
self.order_level = order_level
|
215
|
+
self.execution_tracker = execution_tracker
|
216
|
+
|
217
|
+
def init_state(self):
|
218
|
+
self.state = brainstate.State(jnp.ones(3))
|
219
|
+
|
220
|
+
@brainstate.nn.call_order(0)
|
221
|
+
def reset_state(self):
|
222
|
+
self.execution_tracker.append(f'order_{self.order_level}')
|
223
|
+
self.state.value = jnp.zeros(3)
|
224
|
+
|
225
|
+
|
226
|
+
class NestedResetModule(brainstate.nn.Module):
|
227
|
+
"""Module with nested submodules that have reset_state"""
|
228
|
+
|
229
|
+
def __init__(self):
|
230
|
+
super().__init__()
|
231
|
+
self.submodule1 = ResetTestModule()
|
232
|
+
self.submodule2 = ResetTestModule()
|
233
|
+
|
234
|
+
def init_state(self):
|
235
|
+
self.state = brainstate.State(jnp.ones(3))
|
236
|
+
|
237
|
+
def reset_state(self):
|
238
|
+
self.state.value = jnp.zeros(3)
|
239
|
+
|
240
|
+
|
241
|
+
class Test_reset_all_states:
|
242
|
+
"""Comprehensive tests for reset_all_states function"""
|
243
|
+
|
244
|
+
def test_basic_reset(self):
|
245
|
+
"""Test basic state reset"""
|
246
|
+
module = ResetTestModule()
|
247
|
+
brainstate.nn.init_all_states(module)
|
248
|
+
|
249
|
+
assert module.state_initialized
|
250
|
+
assert not module.state_reset
|
251
|
+
assert jnp.allclose(module.state.value, jnp.ones(5))
|
252
|
+
|
253
|
+
brainstate.nn.reset_all_states(module)
|
254
|
+
|
255
|
+
assert module.state_reset
|
256
|
+
assert jnp.allclose(module.state.value, jnp.zeros(5))
|
257
|
+
|
258
|
+
def test_with_positional_args(self):
|
259
|
+
"""Test reset_all_states with positional arguments"""
|
260
|
+
module = ResetTestModule()
|
261
|
+
brainstate.nn.init_all_states(module)
|
262
|
+
|
263
|
+
brainstate.nn.reset_all_states(module, 1, 2, 3)
|
264
|
+
|
265
|
+
assert module.state_reset
|
266
|
+
assert module.reset_args == (1, 2, 3)
|
267
|
+
|
268
|
+
def test_with_keyword_args(self):
|
269
|
+
"""Test reset_all_states with keyword arguments"""
|
270
|
+
module = ResetTestModule()
|
271
|
+
brainstate.nn.init_all_states(module)
|
272
|
+
|
273
|
+
brainstate.nn.reset_all_states(module, batch_size=10, seq_len=20)
|
274
|
+
|
275
|
+
assert module.state_reset
|
276
|
+
assert module.reset_kwargs == {'batch_size': 10, 'seq_len': 20}
|
277
|
+
|
278
|
+
def test_with_mixed_args(self):
|
279
|
+
"""Test reset_all_states with both positional and keyword arguments"""
|
280
|
+
module = ResetTestModule()
|
281
|
+
brainstate.nn.init_all_states(module)
|
282
|
+
|
283
|
+
brainstate.nn.reset_all_states(module, 42, batch_size=10)
|
284
|
+
|
285
|
+
assert module.state_reset
|
286
|
+
assert module.reset_args == (42,)
|
287
|
+
assert module.reset_kwargs == {'batch_size': 10}
|
288
|
+
|
289
|
+
def test_nested_modules(self):
|
290
|
+
"""Test that reset_all_states resets nested submodules"""
|
291
|
+
module = NestedResetModule()
|
292
|
+
brainstate.nn.init_all_states(module)
|
293
|
+
|
294
|
+
# Verify initial state
|
295
|
+
assert jnp.allclose(module.state.value, jnp.ones(3))
|
296
|
+
assert jnp.allclose(module.submodule1.state.value, jnp.ones(5))
|
297
|
+
assert jnp.allclose(module.submodule2.state.value, jnp.ones(5))
|
298
|
+
|
299
|
+
brainstate.nn.reset_all_states(module)
|
300
|
+
|
301
|
+
# Verify all states were reset
|
302
|
+
assert jnp.allclose(module.state.value, jnp.zeros(3))
|
303
|
+
assert module.submodule1.state_reset
|
304
|
+
assert module.submodule2.state_reset
|
305
|
+
assert jnp.allclose(module.submodule1.state.value, jnp.zeros(5))
|
306
|
+
assert jnp.allclose(module.submodule2.state.value, jnp.zeros(5))
|
307
|
+
|
308
|
+
def test_with_gru_cell(self):
|
309
|
+
"""Test reset with real GRUCell module"""
|
310
|
+
gru = brainstate.nn.GRUCell(5, 10)
|
311
|
+
brainstate.nn.init_all_states(gru, batch_size=8)
|
312
|
+
|
313
|
+
# Get initial state
|
314
|
+
initial_states = {k: v.value.copy() for k, v in gru.states().items()
|
315
|
+
if hasattr(v.value, 'copy') and not isinstance(v.value, dict)}
|
316
|
+
|
317
|
+
# Reset state
|
318
|
+
brainstate.nn.reset_all_states(gru, batch_size=8)
|
319
|
+
|
320
|
+
# Verify state was reset (should be zeros for hidden state)
|
321
|
+
for key in initial_states:
|
322
|
+
current_val = gru.states()[key].value
|
323
|
+
if not isinstance(current_val, dict):
|
324
|
+
# Hidden state should be reset to zeros
|
325
|
+
if 'h' in str(key):
|
326
|
+
assert jnp.allclose(current_val, jnp.zeros_like(current_val))
|
327
|
+
|
328
|
+
def test_sequential_reset(self):
|
329
|
+
"""Test reset with Sequential module"""
|
330
|
+
|
331
|
+
# Create a simple network with stateful components
|
332
|
+
class StatefulLayer(brainstate.nn.Module):
|
333
|
+
def __init__(self):
|
334
|
+
super().__init__()
|
335
|
+
self.reset_called = False
|
336
|
+
|
337
|
+
def init_state(self):
|
338
|
+
self.state = brainstate.State(jnp.ones(5))
|
339
|
+
|
340
|
+
def reset_state(self):
|
341
|
+
self.reset_called = True
|
342
|
+
self.state.value = jnp.zeros(5)
|
343
|
+
|
344
|
+
layer1 = StatefulLayer()
|
345
|
+
layer2 = StatefulLayer()
|
346
|
+
seq = brainstate.nn.Sequential(layer1, layer2)
|
347
|
+
|
348
|
+
brainstate.nn.init_all_states(seq)
|
349
|
+
brainstate.nn.reset_all_states(seq)
|
350
|
+
|
351
|
+
assert layer1.reset_called
|
352
|
+
assert layer2.reset_called
|
353
|
+
|
354
|
+
def test_node_to_exclude(self):
|
355
|
+
"""Test excluding specific nodes from reset"""
|
356
|
+
module = NestedResetModule()
|
357
|
+
brainstate.nn.init_all_states(module)
|
358
|
+
|
359
|
+
# Exclude the parent module from reset
|
360
|
+
brainstate.nn.reset_all_states(
|
361
|
+
module,
|
362
|
+
node_to_exclude=NestedResetModule
|
363
|
+
)
|
364
|
+
|
365
|
+
# Parent should not be reset, but children should be
|
366
|
+
assert jnp.allclose(module.state.value, jnp.ones(3)) # Not reset
|
367
|
+
assert module.submodule1.state_reset # Reset
|
368
|
+
assert module.submodule2.state_reset # Reset
|
369
|
+
|
370
|
+
def test_with_call_order(self):
|
371
|
+
"""Test that call_order is respected during reset"""
|
372
|
+
execution_tracker = []
|
373
|
+
|
374
|
+
class OrderedParent(brainstate.nn.Module):
|
375
|
+
def __init__(self):
|
376
|
+
super().__init__()
|
377
|
+
self.child1 = ResetOrderedModule(1, execution_tracker)
|
378
|
+
self.child2 = ResetOrderedModule(2, execution_tracker)
|
379
|
+
|
380
|
+
def init_state(self):
|
381
|
+
pass
|
382
|
+
|
383
|
+
parent = OrderedParent()
|
384
|
+
brainstate.nn.init_all_states(parent)
|
385
|
+
|
386
|
+
execution_tracker.clear()
|
387
|
+
brainstate.nn.reset_all_states(parent)
|
388
|
+
|
389
|
+
# Both should execute (order 0), check that reset was called
|
390
|
+
assert len(execution_tracker) == 2
|
391
|
+
|
392
|
+
def test_multiple_resets(self):
|
393
|
+
"""Test calling reset_all_states multiple times"""
|
394
|
+
module = ResetTestModule()
|
395
|
+
brainstate.nn.init_all_states(module)
|
396
|
+
|
397
|
+
for i in range(3):
|
398
|
+
brainstate.nn.reset_all_states(module)
|
399
|
+
assert module.state_reset
|
400
|
+
assert jnp.allclose(module.state.value, jnp.zeros(5))
|
401
|
+
|
402
|
+
def test_reset_without_init(self):
|
403
|
+
"""Test that reset works even if init wasn't called explicitly"""
|
404
|
+
gru = brainstate.nn.GRUCell(5, 10)
|
405
|
+
|
406
|
+
# Initialize first
|
407
|
+
brainstate.nn.init_all_states(gru, batch_size=8)
|
408
|
+
|
409
|
+
# Reset should work
|
410
|
+
brainstate.nn.reset_all_states(gru, batch_size=8)
|
411
|
+
|
412
|
+
# Verify it didn't crash
|
413
|
+
states = gru.states()
|
414
|
+
assert len(states) > 0
|
415
|
+
|
416
|
+
|
417
|
+
class CustomMethodModule(brainstate.nn.Module):
|
418
|
+
"""Test module with custom methods for call_all_fns testing"""
|
419
|
+
|
420
|
+
def __init__(self):
|
421
|
+
super().__init__()
|
422
|
+
self.method_called = False
|
423
|
+
self.call_count = 0
|
424
|
+
self.received_args = None
|
425
|
+
self.received_kwargs = None
|
426
|
+
|
427
|
+
def custom_method(self, *args, **kwargs):
|
428
|
+
self.method_called = True
|
429
|
+
self.call_count += 1
|
430
|
+
self.received_args = args
|
431
|
+
self.received_kwargs = kwargs
|
432
|
+
|
433
|
+
def another_method(self):
|
434
|
+
self.call_count += 10
|
435
|
+
|
436
|
+
|
437
|
+
class OrderedCallModule(brainstate.nn.Module):
|
438
|
+
"""Test module with ordered methods"""
|
439
|
+
|
440
|
+
def __init__(self, execution_tracker):
|
441
|
+
super().__init__()
|
442
|
+
self.execution_tracker = execution_tracker
|
443
|
+
|
444
|
+
@brainstate.nn.call_order(0)
|
445
|
+
def ordered_method_0(self):
|
446
|
+
self.execution_tracker.append('order_0')
|
447
|
+
|
448
|
+
@brainstate.nn.call_order(1)
|
449
|
+
def ordered_method_1(self):
|
450
|
+
self.execution_tracker.append('order_1')
|
451
|
+
|
452
|
+
@brainstate.nn.call_order(2)
|
453
|
+
def ordered_method_2(self):
|
454
|
+
self.execution_tracker.append('order_2')
|
455
|
+
|
456
|
+
def unordered_method(self):
|
457
|
+
self.execution_tracker.append('unordered')
|
458
|
+
|
459
|
+
|
460
|
+
class NestedCallModule(brainstate.nn.Module):
|
461
|
+
"""Module with nested submodules for call testing"""
|
462
|
+
|
463
|
+
def __init__(self):
|
464
|
+
super().__init__()
|
465
|
+
self.child1 = CustomMethodModule()
|
466
|
+
self.child2 = CustomMethodModule()
|
467
|
+
self.method_called = False
|
468
|
+
|
469
|
+
def custom_method(self, *args, **kwargs):
|
470
|
+
self.method_called = True
|
471
|
+
|
472
|
+
|
473
|
+
class Test_call_order:
|
474
|
+
"""Comprehensive tests for call_order decorator"""
|
475
|
+
|
476
|
+
def test_basic_call_order(self):
|
477
|
+
"""Test basic call_order decorator"""
|
478
|
+
execution_tracker = []
|
479
|
+
|
480
|
+
class TestModule(brainstate.nn.Module):
|
481
|
+
@brainstate.nn.call_order(0)
|
482
|
+
def method(self):
|
483
|
+
execution_tracker.append('executed')
|
484
|
+
|
485
|
+
module = TestModule()
|
486
|
+
assert hasattr(module.method, 'call_order')
|
487
|
+
assert module.method.call_order == 0
|
488
|
+
|
489
|
+
def test_different_order_levels(self):
|
490
|
+
"""Test different order levels"""
|
491
|
+
for level in [0, 1, 5, 9]:
|
492
|
+
class TestModule(brainstate.nn.Module):
|
493
|
+
@brainstate.nn.call_order(level)
|
494
|
+
def method(self):
|
495
|
+
pass
|
496
|
+
|
497
|
+
module = TestModule()
|
498
|
+
assert module.method.call_order == level
|
499
|
+
|
500
|
+
def test_order_boundary_validation(self):
|
501
|
+
"""Test that order level boundary validation works"""
|
502
|
+
# Valid levels (0 to MAX_ORDER-1)
|
503
|
+
for level in range(brainstate.nn._collective_ops.MAX_ORDER):
|
504
|
+
@brainstate.nn.call_order(level)
|
505
|
+
def valid_method():
|
506
|
+
pass
|
507
|
+
|
508
|
+
assert valid_method.call_order == level
|
509
|
+
|
510
|
+
# Invalid levels
|
511
|
+
with pytest.raises(ValueError, match="must be an integer"):
|
512
|
+
@brainstate.nn.call_order(-1)
|
513
|
+
def invalid_method1():
|
514
|
+
pass
|
515
|
+
|
516
|
+
with pytest.raises(ValueError, match="must be an integer"):
|
517
|
+
@brainstate.nn.call_order(brainstate.nn._collective_ops.MAX_ORDER)
|
518
|
+
def invalid_method2():
|
519
|
+
pass
|
520
|
+
|
521
|
+
def test_disable_boundary_check(self):
|
522
|
+
"""Test disabling boundary check"""
|
523
|
+
|
524
|
+
@brainstate.nn.call_order(100, check_order_boundary=False)
|
525
|
+
def method():
|
526
|
+
pass
|
527
|
+
|
528
|
+
assert method.call_order == 100
|
529
|
+
|
530
|
+
@brainstate.nn.call_order(-5, check_order_boundary=False)
|
531
|
+
def method2():
|
532
|
+
pass
|
533
|
+
|
534
|
+
assert method2.call_order == -5
|
535
|
+
|
536
|
+
def test_order_preserved_on_methods(self):
|
537
|
+
"""Test that call_order is preserved on instance methods"""
|
538
|
+
execution_tracker = []
|
539
|
+
module = OrderedCallModule(execution_tracker)
|
540
|
+
|
541
|
+
assert module.ordered_method_0.call_order == 0
|
542
|
+
assert module.ordered_method_1.call_order == 1
|
543
|
+
assert module.ordered_method_2.call_order == 2
|
544
|
+
assert not hasattr(module.unordered_method, 'call_order')
|
545
|
+
|
546
|
+
def test_multiple_decorators(self):
|
547
|
+
"""Test applying call_order to multiple methods"""
|
548
|
+
execution_tracker = []
|
549
|
+
|
550
|
+
class MultiMethodModule(brainstate.nn.Module):
|
551
|
+
@brainstate.nn.call_order(2)
|
552
|
+
def method_a(self):
|
553
|
+
execution_tracker.append('a')
|
554
|
+
|
555
|
+
@brainstate.nn.call_order(0)
|
556
|
+
def method_b(self):
|
557
|
+
execution_tracker.append('b')
|
558
|
+
|
559
|
+
@brainstate.nn.call_order(1)
|
560
|
+
def method_c(self):
|
561
|
+
execution_tracker.append('c')
|
562
|
+
|
563
|
+
module = MultiMethodModule()
|
564
|
+
assert module.method_a.call_order == 2
|
565
|
+
assert module.method_b.call_order == 0
|
566
|
+
assert module.method_c.call_order == 1
|
567
|
+
|
568
|
+
|
569
|
+
class Test_call_all_fns:
|
570
|
+
"""Comprehensive tests for call_all_fns function"""
|
571
|
+
|
572
|
+
def test_basic_function_call(self):
|
573
|
+
"""Test basic function calling"""
|
574
|
+
module = CustomMethodModule()
|
575
|
+
|
576
|
+
assert not module.method_called
|
577
|
+
brainstate.nn.call_all_fns(module, 'custom_method')
|
578
|
+
assert module.method_called
|
579
|
+
assert module.call_count == 1
|
580
|
+
|
581
|
+
def test_with_positional_args(self):
|
582
|
+
"""Test call_all_fns with positional arguments"""
|
583
|
+
module = CustomMethodModule()
|
584
|
+
|
585
|
+
brainstate.nn.call_all_fns(module, 'custom_method', (1, 2, 3))
|
586
|
+
|
587
|
+
assert module.method_called
|
588
|
+
assert module.received_args == (1, 2, 3)
|
589
|
+
|
590
|
+
def test_with_keyword_args(self):
|
591
|
+
"""Test call_all_fns with keyword arguments"""
|
592
|
+
module = CustomMethodModule()
|
593
|
+
|
594
|
+
brainstate.nn.call_all_fns(module, 'custom_method', kwargs={'foo': 'bar', 'baz': 42})
|
595
|
+
|
596
|
+
assert module.method_called
|
597
|
+
assert module.received_kwargs == {'foo': 'bar', 'baz': 42}
|
598
|
+
|
599
|
+
def test_with_mixed_args(self):
|
600
|
+
"""Test call_all_fns with both positional and keyword arguments"""
|
601
|
+
module = CustomMethodModule()
|
602
|
+
|
603
|
+
brainstate.nn.call_all_fns(
|
604
|
+
module,
|
605
|
+
'custom_method',
|
606
|
+
args=(1, 2),
|
607
|
+
kwargs={'key': 'value'}
|
608
|
+
)
|
609
|
+
|
610
|
+
assert module.method_called
|
611
|
+
assert module.received_args == (1, 2)
|
612
|
+
assert module.received_kwargs == {'key': 'value'}
|
613
|
+
|
614
|
+
def test_nested_modules(self):
|
615
|
+
"""Test that call_all_fns calls methods on nested modules"""
|
616
|
+
module = NestedCallModule()
|
617
|
+
|
618
|
+
brainstate.nn.call_all_fns(module, 'custom_method')
|
619
|
+
|
620
|
+
assert module.method_called
|
621
|
+
assert module.child1.method_called
|
622
|
+
assert module.child2.method_called
|
623
|
+
|
624
|
+
def test_call_order_respected(self):
|
625
|
+
"""Test that call_order is respected"""
|
626
|
+
execution_tracker = []
|
627
|
+
|
628
|
+
class ParentModule(brainstate.nn.Module):
|
629
|
+
def __init__(self):
|
630
|
+
super().__init__()
|
631
|
+
self.child1 = OrderedCallModule(execution_tracker)
|
632
|
+
self.child2 = OrderedCallModule(execution_tracker)
|
633
|
+
|
634
|
+
module = ParentModule()
|
635
|
+
|
636
|
+
# Call ordered_method_1 on all modules (parent doesn't have it, so skip)
|
637
|
+
brainstate.nn.call_all_fns(module, 'ordered_method_1', fn_if_not_exist='pass')
|
638
|
+
|
639
|
+
# Should be called on both children (both have order 1)
|
640
|
+
assert execution_tracker.count('order_1') == 2
|
641
|
+
|
642
|
+
def test_execution_order_with_mixed_decorators(self):
|
643
|
+
"""Test execution order with both ordered and unordered methods"""
|
644
|
+
execution_tracker = []
|
645
|
+
|
646
|
+
class MixedModule(brainstate.nn.Module):
|
647
|
+
def __init__(self):
|
648
|
+
super().__init__()
|
649
|
+
self.ordered = OrderedCallModule(execution_tracker)
|
650
|
+
|
651
|
+
def unordered_method(self):
|
652
|
+
execution_tracker.append('parent_unordered')
|
653
|
+
|
654
|
+
module = MixedModule()
|
655
|
+
|
656
|
+
# Call unordered_method - parent has no decorator, child has no decorator
|
657
|
+
execution_tracker.clear()
|
658
|
+
brainstate.nn.call_all_fns(module, 'unordered_method')
|
659
|
+
|
660
|
+
# Both should be called (unordered methods execute first)
|
661
|
+
assert 'parent_unordered' in execution_tracker
|
662
|
+
assert 'unordered' in execution_tracker
|
663
|
+
|
664
|
+
def test_node_to_exclude(self):
|
665
|
+
"""Test excluding specific nodes"""
|
666
|
+
module = NestedCallModule()
|
667
|
+
|
668
|
+
# Exclude the parent module
|
669
|
+
brainstate.nn.call_all_fns(
|
670
|
+
module,
|
671
|
+
'custom_method',
|
672
|
+
node_to_exclude=NestedCallModule
|
673
|
+
)
|
674
|
+
|
675
|
+
# Parent should not be called, but children should be
|
676
|
+
assert not module.method_called
|
677
|
+
assert module.child1.method_called
|
678
|
+
assert module.child2.method_called
|
679
|
+
|
680
|
+
def test_fn_if_not_exist_raise(self):
|
681
|
+
"""Test fn_if_not_exist='raise' behavior"""
|
682
|
+
module = CustomMethodModule()
|
683
|
+
|
684
|
+
with pytest.raises(AttributeError, match="does not have method"):
|
685
|
+
brainstate.nn.call_all_fns(module, 'nonexistent_method', fn_if_not_exist='raise')
|
686
|
+
|
687
|
+
def test_fn_if_not_exist_pass(self):
|
688
|
+
"""Test fn_if_not_exist='pass' behavior"""
|
689
|
+
module = CustomMethodModule()
|
690
|
+
|
691
|
+
# Should not raise error
|
692
|
+
brainstate.nn.call_all_fns(module, 'nonexistent_method', fn_if_not_exist='pass')
|
693
|
+
|
694
|
+
def test_fn_if_not_exist_none(self):
|
695
|
+
"""Test fn_if_not_exist='none' behavior"""
|
696
|
+
module = CustomMethodModule()
|
697
|
+
|
698
|
+
# Should not raise error
|
699
|
+
brainstate.nn.call_all_fns(module, 'nonexistent_method', fn_if_not_exist='none')
|
700
|
+
|
701
|
+
def test_fn_if_not_exist_warn(self):
|
702
|
+
"""Test fn_if_not_exist='warn' behavior"""
|
703
|
+
module = CustomMethodModule()
|
704
|
+
|
705
|
+
# Should issue warning but not raise
|
706
|
+
with pytest.warns(UserWarning, match="does not have method"):
|
707
|
+
brainstate.nn.call_all_fns(module, 'nonexistent_method', fn_if_not_exist='warn')
|
708
|
+
|
709
|
+
def test_invalid_fn_name_type(self):
|
710
|
+
"""Test that invalid fn_name type raises error"""
|
711
|
+
module = CustomMethodModule()
|
712
|
+
|
713
|
+
with pytest.raises(TypeError, match="fn_name must be a string"):
|
714
|
+
brainstate.nn.call_all_fns(module, 123)
|
715
|
+
|
716
|
+
def test_invalid_kwargs_type(self):
|
717
|
+
"""Test that invalid kwargs type raises error"""
|
718
|
+
module = CustomMethodModule()
|
719
|
+
|
720
|
+
with pytest.raises(TypeError, match="kwargs must be a mapping"):
|
721
|
+
brainstate.nn.call_all_fns(module, 'custom_method', kwargs=[1, 2, 3])
|
722
|
+
|
723
|
+
def test_non_callable_attribute(self):
|
724
|
+
"""Test that non-callable attributes raise error"""
|
725
|
+
|
726
|
+
class ModuleWithAttribute(brainstate.nn.Module):
|
727
|
+
def __init__(self):
|
728
|
+
super().__init__()
|
729
|
+
self.my_attr = "not callable"
|
730
|
+
|
731
|
+
module = ModuleWithAttribute()
|
732
|
+
|
733
|
+
with pytest.raises(TypeError, match="must be callable"):
|
734
|
+
brainstate.nn.call_all_fns(module, 'my_attr')
|
735
|
+
|
736
|
+
def test_with_gru_cell(self):
|
737
|
+
"""Test with real GRU cell"""
|
738
|
+
gru = brainstate.nn.GRUCell(5, 10)
|
739
|
+
|
740
|
+
# Initialize states
|
741
|
+
brainstate.nn.call_all_fns(gru, 'init_state', kwargs={'batch_size': 8})
|
742
|
+
|
743
|
+
# Verify states were created
|
744
|
+
states = gru.states()
|
745
|
+
assert len(states) > 0
|
746
|
+
|
747
|
+
def test_multiple_calls_same_function(self):
|
748
|
+
"""Test calling same function multiple times"""
|
749
|
+
module = CustomMethodModule()
|
750
|
+
|
751
|
+
for i in range(5):
|
752
|
+
brainstate.nn.call_all_fns(module, 'custom_method')
|
753
|
+
|
754
|
+
assert module.call_count == 5
|
755
|
+
|
756
|
+
def test_single_non_tuple_arg(self):
|
757
|
+
"""Test that single non-tuple argument is wrapped"""
|
758
|
+
module = CustomMethodModule()
|
759
|
+
|
760
|
+
brainstate.nn.call_all_fns(module, 'custom_method', args=42)
|
761
|
+
|
762
|
+
assert module.received_args == (42,)
|
763
|
+
|
764
|
+
def test_sequential_module(self):
|
765
|
+
"""Test with Sequential module"""
|
766
|
+
layer1 = CustomMethodModule()
|
767
|
+
layer2 = CustomMethodModule()
|
768
|
+
seq = brainstate.nn.Sequential(layer1, layer2)
|
769
|
+
|
770
|
+
# Sequential itself doesn't have custom_method, so skip it
|
771
|
+
brainstate.nn.call_all_fns(seq, 'custom_method', fn_if_not_exist='pass')
|
772
|
+
|
773
|
+
assert layer1.method_called
|
774
|
+
assert layer2.method_called
|