brainstate 0.1.10__py2.py3-none-any.whl → 0.2.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (163) hide show
  1. brainstate/__init__.py +130 -19
  2. brainstate/_compatible_import.py +201 -9
  3. brainstate/_compatible_import_test.py +681 -0
  4. brainstate/_deprecation.py +210 -0
  5. brainstate/_deprecation_test.py +2319 -0
  6. brainstate/{util/error.py → _error.py} +10 -20
  7. brainstate/_state.py +94 -47
  8. brainstate/_state_test.py +1 -1
  9. brainstate/_utils.py +1 -1
  10. brainstate/environ.py +1279 -347
  11. brainstate/environ_test.py +1187 -26
  12. brainstate/graph/__init__.py +6 -13
  13. brainstate/graph/_node.py +240 -0
  14. brainstate/graph/_node_test.py +589 -0
  15. brainstate/graph/{_graph_operation.py → _operation.py} +632 -746
  16. brainstate/graph/_operation_test.py +1147 -0
  17. brainstate/mixin.py +1209 -141
  18. brainstate/mixin_test.py +991 -51
  19. brainstate/nn/__init__.py +74 -72
  20. brainstate/nn/_activations.py +587 -295
  21. brainstate/nn/_activations_test.py +109 -86
  22. brainstate/nn/_collective_ops.py +393 -274
  23. brainstate/nn/_collective_ops_test.py +746 -15
  24. brainstate/nn/_common.py +114 -66
  25. brainstate/nn/_common_test.py +154 -0
  26. brainstate/nn/_conv.py +1652 -143
  27. brainstate/nn/_conv_test.py +838 -227
  28. brainstate/nn/_delay.py +15 -28
  29. brainstate/nn/_delay_test.py +25 -20
  30. brainstate/nn/_dropout.py +359 -167
  31. brainstate/nn/_dropout_test.py +429 -52
  32. brainstate/nn/_dynamics.py +14 -90
  33. brainstate/nn/_dynamics_test.py +1 -12
  34. brainstate/nn/_elementwise.py +492 -313
  35. brainstate/nn/_elementwise_test.py +806 -145
  36. brainstate/nn/_embedding.py +369 -19
  37. brainstate/nn/_embedding_test.py +156 -0
  38. brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +10 -16
  39. brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +6 -5
  40. brainstate/nn/{_linear_mv.py → _event_linear.py} +2 -2
  41. brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +6 -5
  42. brainstate/nn/_exp_euler.py +200 -38
  43. brainstate/nn/_exp_euler_test.py +350 -8
  44. brainstate/nn/_linear.py +391 -71
  45. brainstate/nn/_linear_test.py +427 -59
  46. brainstate/nn/_metrics.py +1070 -0
  47. brainstate/nn/_metrics_test.py +611 -0
  48. brainstate/nn/_module.py +10 -3
  49. brainstate/nn/_module_test.py +1 -1
  50. brainstate/nn/_normalizations.py +688 -329
  51. brainstate/nn/_normalizations_test.py +663 -37
  52. brainstate/nn/_paddings.py +1020 -0
  53. brainstate/nn/_paddings_test.py +723 -0
  54. brainstate/nn/_poolings.py +1404 -342
  55. brainstate/nn/_poolings_test.py +828 -92
  56. brainstate/nn/{_rate_rnns.py → _rnns.py} +446 -54
  57. brainstate/nn/_rnns_test.py +593 -0
  58. brainstate/nn/_utils.py +132 -5
  59. brainstate/nn/_utils_test.py +402 -0
  60. brainstate/{init/_random_inits.py → nn/init.py} +301 -45
  61. brainstate/{init/_random_inits_test.py → nn/init_test.py} +51 -20
  62. brainstate/random/__init__.py +247 -1
  63. brainstate/random/_rand_funs.py +668 -346
  64. brainstate/random/_rand_funs_test.py +74 -1
  65. brainstate/random/_rand_seed.py +541 -76
  66. brainstate/random/_rand_seed_test.py +1 -1
  67. brainstate/random/_rand_state.py +601 -393
  68. brainstate/random/_rand_state_test.py +551 -0
  69. brainstate/transform/__init__.py +59 -0
  70. brainstate/transform/_ad_checkpoint.py +176 -0
  71. brainstate/{compile → transform}/_ad_checkpoint_test.py +1 -1
  72. brainstate/{augment → transform}/_autograd.py +360 -113
  73. brainstate/{augment → transform}/_autograd_test.py +2 -2
  74. brainstate/transform/_conditions.py +316 -0
  75. brainstate/{compile → transform}/_conditions_test.py +11 -11
  76. brainstate/{compile → transform}/_error_if.py +22 -20
  77. brainstate/{compile → transform}/_error_if_test.py +1 -1
  78. brainstate/transform/_eval_shape.py +145 -0
  79. brainstate/{augment → transform}/_eval_shape_test.py +1 -1
  80. brainstate/{compile → transform}/_jit.py +99 -46
  81. brainstate/{compile → transform}/_jit_test.py +3 -3
  82. brainstate/{compile → transform}/_loop_collect_return.py +219 -80
  83. brainstate/{compile → transform}/_loop_collect_return_test.py +1 -1
  84. brainstate/{compile → transform}/_loop_no_collection.py +133 -34
  85. brainstate/{compile → transform}/_loop_no_collection_test.py +2 -2
  86. brainstate/transform/_make_jaxpr.py +2016 -0
  87. brainstate/transform/_make_jaxpr_test.py +1510 -0
  88. brainstate/transform/_mapping.py +529 -0
  89. brainstate/transform/_mapping_test.py +194 -0
  90. brainstate/{compile → transform}/_progress_bar.py +78 -25
  91. brainstate/{augment → transform}/_random.py +65 -45
  92. brainstate/{compile → transform}/_unvmap.py +102 -5
  93. brainstate/transform/_util.py +286 -0
  94. brainstate/typing.py +594 -61
  95. brainstate/typing_test.py +780 -0
  96. brainstate/util/__init__.py +9 -32
  97. brainstate/util/_others.py +1025 -0
  98. brainstate/util/_others_test.py +962 -0
  99. brainstate/util/_pretty_pytree.py +1301 -0
  100. brainstate/util/_pretty_pytree_test.py +675 -0
  101. brainstate/util/{pretty_repr.py → _pretty_repr.py} +161 -27
  102. brainstate/util/_pretty_repr_test.py +696 -0
  103. brainstate/util/filter.py +557 -81
  104. brainstate/util/filter_test.py +912 -0
  105. brainstate/util/struct.py +769 -382
  106. brainstate/util/struct_test.py +602 -0
  107. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/METADATA +34 -17
  108. brainstate-0.2.0.dist-info/RECORD +111 -0
  109. brainstate/augment/__init__.py +0 -30
  110. brainstate/augment/_eval_shape.py +0 -99
  111. brainstate/augment/_mapping.py +0 -1060
  112. brainstate/augment/_mapping_test.py +0 -597
  113. brainstate/compile/__init__.py +0 -38
  114. brainstate/compile/_ad_checkpoint.py +0 -204
  115. brainstate/compile/_conditions.py +0 -256
  116. brainstate/compile/_make_jaxpr.py +0 -888
  117. brainstate/compile/_make_jaxpr_test.py +0 -156
  118. brainstate/compile/_util.py +0 -147
  119. brainstate/functional/__init__.py +0 -27
  120. brainstate/graph/_graph_node.py +0 -244
  121. brainstate/graph/_graph_node_test.py +0 -73
  122. brainstate/graph/_graph_operation_test.py +0 -563
  123. brainstate/init/__init__.py +0 -26
  124. brainstate/init/_base.py +0 -52
  125. brainstate/init/_generic.py +0 -244
  126. brainstate/init/_regular_inits.py +0 -105
  127. brainstate/init/_regular_inits_test.py +0 -50
  128. brainstate/nn/_inputs.py +0 -608
  129. brainstate/nn/_ltp.py +0 -28
  130. brainstate/nn/_neuron.py +0 -705
  131. brainstate/nn/_neuron_test.py +0 -161
  132. brainstate/nn/_others.py +0 -46
  133. brainstate/nn/_projection.py +0 -486
  134. brainstate/nn/_rate_rnns_test.py +0 -63
  135. brainstate/nn/_readout.py +0 -209
  136. brainstate/nn/_readout_test.py +0 -53
  137. brainstate/nn/_stp.py +0 -236
  138. brainstate/nn/_synapse.py +0 -505
  139. brainstate/nn/_synapse_test.py +0 -131
  140. brainstate/nn/_synaptic_projection.py +0 -423
  141. brainstate/nn/_synouts.py +0 -162
  142. brainstate/nn/_synouts_test.py +0 -57
  143. brainstate/nn/metrics.py +0 -388
  144. brainstate/optim/__init__.py +0 -38
  145. brainstate/optim/_base.py +0 -64
  146. brainstate/optim/_lr_scheduler.py +0 -448
  147. brainstate/optim/_lr_scheduler_test.py +0 -50
  148. brainstate/optim/_optax_optimizer.py +0 -152
  149. brainstate/optim/_optax_optimizer_test.py +0 -53
  150. brainstate/optim/_sgd_optimizer.py +0 -1104
  151. brainstate/random/_random_for_unit.py +0 -52
  152. brainstate/surrogate.py +0 -1957
  153. brainstate/transform.py +0 -23
  154. brainstate/util/caller.py +0 -98
  155. brainstate/util/others.py +0 -540
  156. brainstate/util/pretty_pytree.py +0 -945
  157. brainstate/util/pretty_pytree_test.py +0 -159
  158. brainstate/util/pretty_table.py +0 -2954
  159. brainstate/util/scaling.py +0 -258
  160. brainstate-0.1.10.dist-info/RECORD +0 -130
  161. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/WHEEL +0 -0
  162. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/licenses/LICENSE +0 -0
  163. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/top_level.txt +0 -0
@@ -1,4 +1,4 @@
1
- # Copyright 2025 BDP Ecosystem Limited. All Rights Reserved.
1
+ # Copyright 2025 BrainX Ecosystem Limited. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -16,28 +16,759 @@
16
16
  # -*- coding: utf-8 -*-
17
17
 
18
18
 
19
+ import jax.numpy as jnp
20
+ import pytest
21
+
19
22
  import brainstate
20
23
 
21
24
 
22
- class Test_vmap_init_all_states:
25
+ class SimpleTestModule(brainstate.nn.Module):
26
+ """Simple test module with init_state method"""
23
27
 
24
- def test_vmap_init_all_states(self):
25
- gru = brainstate.nn.GRUCell(1, 2)
26
- brainstate.nn.vmap_init_all_states(gru, axis_size=10)
27
- print(gru)
28
+ def __init__(self):
29
+ super().__init__()
30
+ self.state_initialized = False
31
+ self.init_args = None
32
+ self.init_kwargs = None
33
+
34
+ def init_state(self, *args, **kwargs):
35
+ self.state_initialized = True
36
+ self.init_args = args
37
+ self.init_kwargs = kwargs
38
+ self.state = brainstate.State(jnp.zeros(5))
39
+
40
+
41
+ class OrderedTestModule(brainstate.nn.Module):
42
+ """Test module with call_order decorator"""
43
+
44
+ def __init__(self, order_level):
45
+ super().__init__()
46
+ self.order_level = order_level
47
+ self.call_sequence = []
48
+
49
+ @brainstate.nn.call_order(0)
50
+ def init_state(self):
51
+ self.state = brainstate.State(jnp.array([self.order_level]))
28
52
 
29
- def test_vmap_init_all_states_v2(self):
30
- @brainstate.compile.jit
31
- def init():
32
- gru = brainstate.nn.GRUCell(1, 2)
33
- brainstate.nn.vmap_init_all_states(gru, axis_size=10)
34
- print(gru)
35
53
 
36
- init()
54
+ class NestedModule(brainstate.nn.Module):
55
+ """Module with nested submodules"""
56
+
57
+ def __init__(self):
58
+ super().__init__()
59
+ self.submodule1 = SimpleTestModule()
60
+ self.submodule2 = SimpleTestModule()
61
+
62
+ def init_state(self):
63
+ self.state = brainstate.State(jnp.zeros(3))
37
64
 
38
65
 
39
66
  class Test_init_all_states:
40
- def test_init_all_states(self):
67
+ """Comprehensive tests for init_all_states function"""
68
+
69
+ def test_basic_initialization(self):
70
+ """Test basic state initialization"""
71
+ module = SimpleTestModule()
72
+ assert not module.state_initialized
73
+
74
+ brainstate.nn.init_all_states(module)
75
+
76
+ assert module.state_initialized
77
+ assert module.state.value.shape == (5,)
78
+
79
+ def test_with_positional_args(self):
80
+ """Test init_all_states with positional arguments"""
81
+ module = SimpleTestModule()
82
+
83
+ brainstate.nn.init_all_states(module, 1, 2, 3)
84
+
85
+ assert module.state_initialized
86
+ assert module.init_args == (1, 2, 3)
87
+
88
+ def test_with_keyword_args(self):
89
+ """Test init_all_states with keyword arguments"""
90
+ module = SimpleTestModule()
91
+
92
+ brainstate.nn.init_all_states(module, batch_size=10, seq_len=20)
93
+
94
+ assert module.state_initialized
95
+ assert module.init_kwargs == {'batch_size': 10, 'seq_len': 20}
96
+
97
+ def test_with_mixed_args(self):
98
+ """Test init_all_states with both positional and keyword arguments"""
99
+ module = SimpleTestModule()
100
+
101
+ brainstate.nn.init_all_states(module, 42, batch_size=10)
102
+
103
+ assert module.state_initialized
104
+ assert module.init_args == (42,)
105
+ assert module.init_kwargs == {'batch_size': 10}
106
+
107
+ def test_nested_modules(self):
108
+ """Test that init_all_states initializes nested submodules"""
109
+ module = NestedModule()
110
+
111
+ brainstate.nn.init_all_states(module)
112
+
113
+ assert module.submodule1.state_initialized
114
+ assert module.submodule2.state_initialized
115
+ assert hasattr(module, 'state')
116
+
117
+ def test_with_gru_cell(self):
118
+ """Test with real GRUCell module"""
41
119
  gru = brainstate.nn.GRUCell(1, 2)
120
+
42
121
  brainstate.nn.init_all_states(gru, batch_size=10)
43
- print(gru)
122
+
123
+ # Check that states were created
124
+ state_dict = gru.states()
125
+ assert len(state_dict) > 0
126
+
127
+ def test_sequential_module(self):
128
+ """Test with Sequential module containing multiple layers"""
129
+ seq = brainstate.nn.Sequential(
130
+ brainstate.nn.Linear(10, 20),
131
+ brainstate.nn.Dropout(0.5)
132
+ )
133
+
134
+ brainstate.nn.init_all_states(seq)
135
+
136
+ # Check that Linear layer has weight and bias
137
+ state_dict = seq.states()
138
+ assert len(state_dict) > 0
139
+
140
+ def test_node_to_exclude(self):
141
+ """Test excluding specific nodes from initialization"""
142
+ module = NestedModule()
143
+
144
+ # Exclude submodule1 by type matching - simpler and more reliable
145
+ brainstate.nn.init_all_states(
146
+ module,
147
+ node_to_exclude=NestedModule # Exclude the parent, only init children
148
+ )
149
+
150
+ # Parent should not be initialized, but children should be
151
+ assert not hasattr(module, 'state') or module.state is None or not hasattr(module.state, 'value')
152
+ assert module.submodule1.state_initialized
153
+ assert module.submodule2.state_initialized
154
+
155
+ def test_with_call_order(self):
156
+ """Test that call_order is respected during initialization"""
157
+
158
+ class OrderedModule(brainstate.nn.Module):
159
+ def __init__(self):
160
+ super().__init__()
161
+ self.execution_order = []
162
+
163
+ @brainstate.nn.call_order(1)
164
+ def init_state(self):
165
+ self.execution_order.append('parent')
166
+
167
+ class ChildModule(brainstate.nn.Module):
168
+ def __init__(self, parent_module):
169
+ super().__init__()
170
+ self.parent = parent_module
171
+
172
+ @brainstate.nn.call_order(0)
173
+ def init_state(self):
174
+ self.parent.execution_order.append('child')
175
+
176
+ parent = OrderedModule()
177
+ child = ChildModule(parent)
178
+ parent.child_module = child
179
+
180
+ brainstate.nn.init_all_states(parent)
181
+
182
+ # Child (order 0) should execute before parent (order 1)
183
+ assert parent.execution_order == ['child', 'parent']
184
+
185
+
186
+ class ResetTestModule(brainstate.nn.Module):
187
+ """Test module with both init_state and reset_state methods"""
188
+
189
+ def __init__(self):
190
+ super().__init__()
191
+ self.state_initialized = False
192
+ self.state_reset = False
193
+ self.reset_args = None
194
+ self.reset_kwargs = None
195
+
196
+ def init_state(self, *args, **kwargs):
197
+ self.state_initialized = True
198
+ self.state_reset = False
199
+ self.state = brainstate.State(jnp.ones(5))
200
+
201
+ def reset_state(self, *args, **kwargs):
202
+ self.state_reset = True
203
+ self.reset_args = args
204
+ self.reset_kwargs = kwargs
205
+ if hasattr(self, 'state'):
206
+ self.state.value = jnp.zeros(5)
207
+
208
+
209
+ class ResetOrderedModule(brainstate.nn.Module):
210
+ """Test module with call_order on reset_state"""
211
+
212
+ def __init__(self, order_level, execution_tracker):
213
+ super().__init__()
214
+ self.order_level = order_level
215
+ self.execution_tracker = execution_tracker
216
+
217
+ def init_state(self):
218
+ self.state = brainstate.State(jnp.ones(3))
219
+
220
+ @brainstate.nn.call_order(0)
221
+ def reset_state(self):
222
+ self.execution_tracker.append(f'order_{self.order_level}')
223
+ self.state.value = jnp.zeros(3)
224
+
225
+
226
+ class NestedResetModule(brainstate.nn.Module):
227
+ """Module with nested submodules that have reset_state"""
228
+
229
+ def __init__(self):
230
+ super().__init__()
231
+ self.submodule1 = ResetTestModule()
232
+ self.submodule2 = ResetTestModule()
233
+
234
+ def init_state(self):
235
+ self.state = brainstate.State(jnp.ones(3))
236
+
237
+ def reset_state(self):
238
+ self.state.value = jnp.zeros(3)
239
+
240
+
241
+ class Test_reset_all_states:
242
+ """Comprehensive tests for reset_all_states function"""
243
+
244
+ def test_basic_reset(self):
245
+ """Test basic state reset"""
246
+ module = ResetTestModule()
247
+ brainstate.nn.init_all_states(module)
248
+
249
+ assert module.state_initialized
250
+ assert not module.state_reset
251
+ assert jnp.allclose(module.state.value, jnp.ones(5))
252
+
253
+ brainstate.nn.reset_all_states(module)
254
+
255
+ assert module.state_reset
256
+ assert jnp.allclose(module.state.value, jnp.zeros(5))
257
+
258
+ def test_with_positional_args(self):
259
+ """Test reset_all_states with positional arguments"""
260
+ module = ResetTestModule()
261
+ brainstate.nn.init_all_states(module)
262
+
263
+ brainstate.nn.reset_all_states(module, 1, 2, 3)
264
+
265
+ assert module.state_reset
266
+ assert module.reset_args == (1, 2, 3)
267
+
268
+ def test_with_keyword_args(self):
269
+ """Test reset_all_states with keyword arguments"""
270
+ module = ResetTestModule()
271
+ brainstate.nn.init_all_states(module)
272
+
273
+ brainstate.nn.reset_all_states(module, batch_size=10, seq_len=20)
274
+
275
+ assert module.state_reset
276
+ assert module.reset_kwargs == {'batch_size': 10, 'seq_len': 20}
277
+
278
+ def test_with_mixed_args(self):
279
+ """Test reset_all_states with both positional and keyword arguments"""
280
+ module = ResetTestModule()
281
+ brainstate.nn.init_all_states(module)
282
+
283
+ brainstate.nn.reset_all_states(module, 42, batch_size=10)
284
+
285
+ assert module.state_reset
286
+ assert module.reset_args == (42,)
287
+ assert module.reset_kwargs == {'batch_size': 10}
288
+
289
+ def test_nested_modules(self):
290
+ """Test that reset_all_states resets nested submodules"""
291
+ module = NestedResetModule()
292
+ brainstate.nn.init_all_states(module)
293
+
294
+ # Verify initial state
295
+ assert jnp.allclose(module.state.value, jnp.ones(3))
296
+ assert jnp.allclose(module.submodule1.state.value, jnp.ones(5))
297
+ assert jnp.allclose(module.submodule2.state.value, jnp.ones(5))
298
+
299
+ brainstate.nn.reset_all_states(module)
300
+
301
+ # Verify all states were reset
302
+ assert jnp.allclose(module.state.value, jnp.zeros(3))
303
+ assert module.submodule1.state_reset
304
+ assert module.submodule2.state_reset
305
+ assert jnp.allclose(module.submodule1.state.value, jnp.zeros(5))
306
+ assert jnp.allclose(module.submodule2.state.value, jnp.zeros(5))
307
+
308
+ def test_with_gru_cell(self):
309
+ """Test reset with real GRUCell module"""
310
+ gru = brainstate.nn.GRUCell(5, 10)
311
+ brainstate.nn.init_all_states(gru, batch_size=8)
312
+
313
+ # Get initial state
314
+ initial_states = {k: v.value.copy() for k, v in gru.states().items()
315
+ if hasattr(v.value, 'copy') and not isinstance(v.value, dict)}
316
+
317
+ # Reset state
318
+ brainstate.nn.reset_all_states(gru, batch_size=8)
319
+
320
+ # Verify state was reset (should be zeros for hidden state)
321
+ for key in initial_states:
322
+ current_val = gru.states()[key].value
323
+ if not isinstance(current_val, dict):
324
+ # Hidden state should be reset to zeros
325
+ if 'h' in str(key):
326
+ assert jnp.allclose(current_val, jnp.zeros_like(current_val))
327
+
328
+ def test_sequential_reset(self):
329
+ """Test reset with Sequential module"""
330
+
331
+ # Create a simple network with stateful components
332
+ class StatefulLayer(brainstate.nn.Module):
333
+ def __init__(self):
334
+ super().__init__()
335
+ self.reset_called = False
336
+
337
+ def init_state(self):
338
+ self.state = brainstate.State(jnp.ones(5))
339
+
340
+ def reset_state(self):
341
+ self.reset_called = True
342
+ self.state.value = jnp.zeros(5)
343
+
344
+ layer1 = StatefulLayer()
345
+ layer2 = StatefulLayer()
346
+ seq = brainstate.nn.Sequential(layer1, layer2)
347
+
348
+ brainstate.nn.init_all_states(seq)
349
+ brainstate.nn.reset_all_states(seq)
350
+
351
+ assert layer1.reset_called
352
+ assert layer2.reset_called
353
+
354
+ def test_node_to_exclude(self):
355
+ """Test excluding specific nodes from reset"""
356
+ module = NestedResetModule()
357
+ brainstate.nn.init_all_states(module)
358
+
359
+ # Exclude the parent module from reset
360
+ brainstate.nn.reset_all_states(
361
+ module,
362
+ node_to_exclude=NestedResetModule
363
+ )
364
+
365
+ # Parent should not be reset, but children should be
366
+ assert jnp.allclose(module.state.value, jnp.ones(3)) # Not reset
367
+ assert module.submodule1.state_reset # Reset
368
+ assert module.submodule2.state_reset # Reset
369
+
370
+ def test_with_call_order(self):
371
+ """Test that call_order is respected during reset"""
372
+ execution_tracker = []
373
+
374
+ class OrderedParent(brainstate.nn.Module):
375
+ def __init__(self):
376
+ super().__init__()
377
+ self.child1 = ResetOrderedModule(1, execution_tracker)
378
+ self.child2 = ResetOrderedModule(2, execution_tracker)
379
+
380
+ def init_state(self):
381
+ pass
382
+
383
+ parent = OrderedParent()
384
+ brainstate.nn.init_all_states(parent)
385
+
386
+ execution_tracker.clear()
387
+ brainstate.nn.reset_all_states(parent)
388
+
389
+ # Both should execute (order 0), check that reset was called
390
+ assert len(execution_tracker) == 2
391
+
392
+ def test_multiple_resets(self):
393
+ """Test calling reset_all_states multiple times"""
394
+ module = ResetTestModule()
395
+ brainstate.nn.init_all_states(module)
396
+
397
+ for i in range(3):
398
+ brainstate.nn.reset_all_states(module)
399
+ assert module.state_reset
400
+ assert jnp.allclose(module.state.value, jnp.zeros(5))
401
+
402
+ def test_reset_without_init(self):
403
+ """Test that reset works even if init wasn't called explicitly"""
404
+ gru = brainstate.nn.GRUCell(5, 10)
405
+
406
+ # Initialize first
407
+ brainstate.nn.init_all_states(gru, batch_size=8)
408
+
409
+ # Reset should work
410
+ brainstate.nn.reset_all_states(gru, batch_size=8)
411
+
412
+ # Verify it didn't crash
413
+ states = gru.states()
414
+ assert len(states) > 0
415
+
416
+
417
+ class CustomMethodModule(brainstate.nn.Module):
418
+ """Test module with custom methods for call_all_fns testing"""
419
+
420
+ def __init__(self):
421
+ super().__init__()
422
+ self.method_called = False
423
+ self.call_count = 0
424
+ self.received_args = None
425
+ self.received_kwargs = None
426
+
427
+ def custom_method(self, *args, **kwargs):
428
+ self.method_called = True
429
+ self.call_count += 1
430
+ self.received_args = args
431
+ self.received_kwargs = kwargs
432
+
433
+ def another_method(self):
434
+ self.call_count += 10
435
+
436
+
437
+ class OrderedCallModule(brainstate.nn.Module):
438
+ """Test module with ordered methods"""
439
+
440
+ def __init__(self, execution_tracker):
441
+ super().__init__()
442
+ self.execution_tracker = execution_tracker
443
+
444
+ @brainstate.nn.call_order(0)
445
+ def ordered_method_0(self):
446
+ self.execution_tracker.append('order_0')
447
+
448
+ @brainstate.nn.call_order(1)
449
+ def ordered_method_1(self):
450
+ self.execution_tracker.append('order_1')
451
+
452
+ @brainstate.nn.call_order(2)
453
+ def ordered_method_2(self):
454
+ self.execution_tracker.append('order_2')
455
+
456
+ def unordered_method(self):
457
+ self.execution_tracker.append('unordered')
458
+
459
+
460
+ class NestedCallModule(brainstate.nn.Module):
461
+ """Module with nested submodules for call testing"""
462
+
463
+ def __init__(self):
464
+ super().__init__()
465
+ self.child1 = CustomMethodModule()
466
+ self.child2 = CustomMethodModule()
467
+ self.method_called = False
468
+
469
+ def custom_method(self, *args, **kwargs):
470
+ self.method_called = True
471
+
472
+
473
+ class Test_call_order:
474
+ """Comprehensive tests for call_order decorator"""
475
+
476
+ def test_basic_call_order(self):
477
+ """Test basic call_order decorator"""
478
+ execution_tracker = []
479
+
480
+ class TestModule(brainstate.nn.Module):
481
+ @brainstate.nn.call_order(0)
482
+ def method(self):
483
+ execution_tracker.append('executed')
484
+
485
+ module = TestModule()
486
+ assert hasattr(module.method, 'call_order')
487
+ assert module.method.call_order == 0
488
+
489
+ def test_different_order_levels(self):
490
+ """Test different order levels"""
491
+ for level in [0, 1, 5, 9]:
492
+ class TestModule(brainstate.nn.Module):
493
+ @brainstate.nn.call_order(level)
494
+ def method(self):
495
+ pass
496
+
497
+ module = TestModule()
498
+ assert module.method.call_order == level
499
+
500
+ def test_order_boundary_validation(self):
501
+ """Test that order level boundary validation works"""
502
+ # Valid levels (0 to MAX_ORDER-1)
503
+ for level in range(brainstate.nn._collective_ops.MAX_ORDER):
504
+ @brainstate.nn.call_order(level)
505
+ def valid_method():
506
+ pass
507
+
508
+ assert valid_method.call_order == level
509
+
510
+ # Invalid levels
511
+ with pytest.raises(ValueError, match="must be an integer"):
512
+ @brainstate.nn.call_order(-1)
513
+ def invalid_method1():
514
+ pass
515
+
516
+ with pytest.raises(ValueError, match="must be an integer"):
517
+ @brainstate.nn.call_order(brainstate.nn._collective_ops.MAX_ORDER)
518
+ def invalid_method2():
519
+ pass
520
+
521
+ def test_disable_boundary_check(self):
522
+ """Test disabling boundary check"""
523
+
524
+ @brainstate.nn.call_order(100, check_order_boundary=False)
525
+ def method():
526
+ pass
527
+
528
+ assert method.call_order == 100
529
+
530
+ @brainstate.nn.call_order(-5, check_order_boundary=False)
531
+ def method2():
532
+ pass
533
+
534
+ assert method2.call_order == -5
535
+
536
+ def test_order_preserved_on_methods(self):
537
+ """Test that call_order is preserved on instance methods"""
538
+ execution_tracker = []
539
+ module = OrderedCallModule(execution_tracker)
540
+
541
+ assert module.ordered_method_0.call_order == 0
542
+ assert module.ordered_method_1.call_order == 1
543
+ assert module.ordered_method_2.call_order == 2
544
+ assert not hasattr(module.unordered_method, 'call_order')
545
+
546
+ def test_multiple_decorators(self):
547
+ """Test applying call_order to multiple methods"""
548
+ execution_tracker = []
549
+
550
+ class MultiMethodModule(brainstate.nn.Module):
551
+ @brainstate.nn.call_order(2)
552
+ def method_a(self):
553
+ execution_tracker.append('a')
554
+
555
+ @brainstate.nn.call_order(0)
556
+ def method_b(self):
557
+ execution_tracker.append('b')
558
+
559
+ @brainstate.nn.call_order(1)
560
+ def method_c(self):
561
+ execution_tracker.append('c')
562
+
563
+ module = MultiMethodModule()
564
+ assert module.method_a.call_order == 2
565
+ assert module.method_b.call_order == 0
566
+ assert module.method_c.call_order == 1
567
+
568
+
569
+ class Test_call_all_fns:
570
+ """Comprehensive tests for call_all_fns function"""
571
+
572
+ def test_basic_function_call(self):
573
+ """Test basic function calling"""
574
+ module = CustomMethodModule()
575
+
576
+ assert not module.method_called
577
+ brainstate.nn.call_all_fns(module, 'custom_method')
578
+ assert module.method_called
579
+ assert module.call_count == 1
580
+
581
+ def test_with_positional_args(self):
582
+ """Test call_all_fns with positional arguments"""
583
+ module = CustomMethodModule()
584
+
585
+ brainstate.nn.call_all_fns(module, 'custom_method', (1, 2, 3))
586
+
587
+ assert module.method_called
588
+ assert module.received_args == (1, 2, 3)
589
+
590
+ def test_with_keyword_args(self):
591
+ """Test call_all_fns with keyword arguments"""
592
+ module = CustomMethodModule()
593
+
594
+ brainstate.nn.call_all_fns(module, 'custom_method', kwargs={'foo': 'bar', 'baz': 42})
595
+
596
+ assert module.method_called
597
+ assert module.received_kwargs == {'foo': 'bar', 'baz': 42}
598
+
599
+ def test_with_mixed_args(self):
600
+ """Test call_all_fns with both positional and keyword arguments"""
601
+ module = CustomMethodModule()
602
+
603
+ brainstate.nn.call_all_fns(
604
+ module,
605
+ 'custom_method',
606
+ args=(1, 2),
607
+ kwargs={'key': 'value'}
608
+ )
609
+
610
+ assert module.method_called
611
+ assert module.received_args == (1, 2)
612
+ assert module.received_kwargs == {'key': 'value'}
613
+
614
+ def test_nested_modules(self):
615
+ """Test that call_all_fns calls methods on nested modules"""
616
+ module = NestedCallModule()
617
+
618
+ brainstate.nn.call_all_fns(module, 'custom_method')
619
+
620
+ assert module.method_called
621
+ assert module.child1.method_called
622
+ assert module.child2.method_called
623
+
624
+ def test_call_order_respected(self):
625
+ """Test that call_order is respected"""
626
+ execution_tracker = []
627
+
628
+ class ParentModule(brainstate.nn.Module):
629
+ def __init__(self):
630
+ super().__init__()
631
+ self.child1 = OrderedCallModule(execution_tracker)
632
+ self.child2 = OrderedCallModule(execution_tracker)
633
+
634
+ module = ParentModule()
635
+
636
+ # Call ordered_method_1 on all modules (parent doesn't have it, so skip)
637
+ brainstate.nn.call_all_fns(module, 'ordered_method_1', fn_if_not_exist='pass')
638
+
639
+ # Should be called on both children (both have order 1)
640
+ assert execution_tracker.count('order_1') == 2
641
+
642
+ def test_execution_order_with_mixed_decorators(self):
643
+ """Test execution order with both ordered and unordered methods"""
644
+ execution_tracker = []
645
+
646
+ class MixedModule(brainstate.nn.Module):
647
+ def __init__(self):
648
+ super().__init__()
649
+ self.ordered = OrderedCallModule(execution_tracker)
650
+
651
+ def unordered_method(self):
652
+ execution_tracker.append('parent_unordered')
653
+
654
+ module = MixedModule()
655
+
656
+ # Call unordered_method - parent has no decorator, child has no decorator
657
+ execution_tracker.clear()
658
+ brainstate.nn.call_all_fns(module, 'unordered_method')
659
+
660
+ # Both should be called (unordered methods execute first)
661
+ assert 'parent_unordered' in execution_tracker
662
+ assert 'unordered' in execution_tracker
663
+
664
+ def test_node_to_exclude(self):
665
+ """Test excluding specific nodes"""
666
+ module = NestedCallModule()
667
+
668
+ # Exclude the parent module
669
+ brainstate.nn.call_all_fns(
670
+ module,
671
+ 'custom_method',
672
+ node_to_exclude=NestedCallModule
673
+ )
674
+
675
+ # Parent should not be called, but children should be
676
+ assert not module.method_called
677
+ assert module.child1.method_called
678
+ assert module.child2.method_called
679
+
680
+ def test_fn_if_not_exist_raise(self):
681
+ """Test fn_if_not_exist='raise' behavior"""
682
+ module = CustomMethodModule()
683
+
684
+ with pytest.raises(AttributeError, match="does not have method"):
685
+ brainstate.nn.call_all_fns(module, 'nonexistent_method', fn_if_not_exist='raise')
686
+
687
+ def test_fn_if_not_exist_pass(self):
688
+ """Test fn_if_not_exist='pass' behavior"""
689
+ module = CustomMethodModule()
690
+
691
+ # Should not raise error
692
+ brainstate.nn.call_all_fns(module, 'nonexistent_method', fn_if_not_exist='pass')
693
+
694
+ def test_fn_if_not_exist_none(self):
695
+ """Test fn_if_not_exist='none' behavior"""
696
+ module = CustomMethodModule()
697
+
698
+ # Should not raise error
699
+ brainstate.nn.call_all_fns(module, 'nonexistent_method', fn_if_not_exist='none')
700
+
701
+ def test_fn_if_not_exist_warn(self):
702
+ """Test fn_if_not_exist='warn' behavior"""
703
+ module = CustomMethodModule()
704
+
705
+ # Should issue warning but not raise
706
+ with pytest.warns(UserWarning, match="does not have method"):
707
+ brainstate.nn.call_all_fns(module, 'nonexistent_method', fn_if_not_exist='warn')
708
+
709
+ def test_invalid_fn_name_type(self):
710
+ """Test that invalid fn_name type raises error"""
711
+ module = CustomMethodModule()
712
+
713
+ with pytest.raises(TypeError, match="fn_name must be a string"):
714
+ brainstate.nn.call_all_fns(module, 123)
715
+
716
+ def test_invalid_kwargs_type(self):
717
+ """Test that invalid kwargs type raises error"""
718
+ module = CustomMethodModule()
719
+
720
+ with pytest.raises(TypeError, match="kwargs must be a mapping"):
721
+ brainstate.nn.call_all_fns(module, 'custom_method', kwargs=[1, 2, 3])
722
+
723
+ def test_non_callable_attribute(self):
724
+ """Test that non-callable attributes raise error"""
725
+
726
+ class ModuleWithAttribute(brainstate.nn.Module):
727
+ def __init__(self):
728
+ super().__init__()
729
+ self.my_attr = "not callable"
730
+
731
+ module = ModuleWithAttribute()
732
+
733
+ with pytest.raises(TypeError, match="must be callable"):
734
+ brainstate.nn.call_all_fns(module, 'my_attr')
735
+
736
+ def test_with_gru_cell(self):
737
+ """Test with real GRU cell"""
738
+ gru = brainstate.nn.GRUCell(5, 10)
739
+
740
+ # Initialize states
741
+ brainstate.nn.call_all_fns(gru, 'init_state', kwargs={'batch_size': 8})
742
+
743
+ # Verify states were created
744
+ states = gru.states()
745
+ assert len(states) > 0
746
+
747
+ def test_multiple_calls_same_function(self):
748
+ """Test calling same function multiple times"""
749
+ module = CustomMethodModule()
750
+
751
+ for i in range(5):
752
+ brainstate.nn.call_all_fns(module, 'custom_method')
753
+
754
+ assert module.call_count == 5
755
+
756
+ def test_single_non_tuple_arg(self):
757
+ """Test that single non-tuple argument is wrapped"""
758
+ module = CustomMethodModule()
759
+
760
+ brainstate.nn.call_all_fns(module, 'custom_method', args=42)
761
+
762
+ assert module.received_args == (42,)
763
+
764
+ def test_sequential_module(self):
765
+ """Test with Sequential module"""
766
+ layer1 = CustomMethodModule()
767
+ layer2 = CustomMethodModule()
768
+ seq = brainstate.nn.Sequential(layer1, layer2)
769
+
770
+ # Sequential itself doesn't have custom_method, so skip it
771
+ brainstate.nn.call_all_fns(seq, 'custom_method', fn_if_not_exist='pass')
772
+
773
+ assert layer1.method_called
774
+ assert layer2.method_called