brainstate 0.2.0__py2.py3-none-any.whl → 0.2.1__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (112) hide show
  1. brainstate/__init__.py +169 -169
  2. brainstate/_compatible_import.py +340 -340
  3. brainstate/_compatible_import_test.py +681 -681
  4. brainstate/_deprecation.py +210 -210
  5. brainstate/_deprecation_test.py +2319 -2319
  6. brainstate/_error.py +45 -45
  7. brainstate/_state.py +1652 -1652
  8. brainstate/_state_test.py +52 -52
  9. brainstate/_utils.py +47 -47
  10. brainstate/environ.py +1495 -1495
  11. brainstate/environ_test.py +1223 -1223
  12. brainstate/graph/__init__.py +22 -22
  13. brainstate/graph/_node.py +240 -240
  14. brainstate/graph/_node_test.py +589 -589
  15. brainstate/graph/_operation.py +1624 -1624
  16. brainstate/graph/_operation_test.py +1147 -1147
  17. brainstate/mixin.py +1433 -1433
  18. brainstate/mixin_test.py +1017 -1017
  19. brainstate/nn/__init__.py +137 -137
  20. brainstate/nn/_activations.py +1100 -1100
  21. brainstate/nn/_activations_test.py +354 -354
  22. brainstate/nn/_collective_ops.py +633 -633
  23. brainstate/nn/_collective_ops_test.py +774 -774
  24. brainstate/nn/_common.py +226 -226
  25. brainstate/nn/_common_test.py +154 -154
  26. brainstate/nn/_conv.py +2010 -2010
  27. brainstate/nn/_conv_test.py +849 -849
  28. brainstate/nn/_delay.py +575 -575
  29. brainstate/nn/_delay_test.py +243 -243
  30. brainstate/nn/_dropout.py +618 -618
  31. brainstate/nn/_dropout_test.py +477 -477
  32. brainstate/nn/_dynamics.py +1267 -1267
  33. brainstate/nn/_dynamics_test.py +67 -67
  34. brainstate/nn/_elementwise.py +1298 -1298
  35. brainstate/nn/_elementwise_test.py +829 -829
  36. brainstate/nn/_embedding.py +408 -408
  37. brainstate/nn/_embedding_test.py +156 -156
  38. brainstate/nn/_event_fixedprob.py +233 -233
  39. brainstate/nn/_event_fixedprob_test.py +115 -115
  40. brainstate/nn/_event_linear.py +83 -83
  41. brainstate/nn/_event_linear_test.py +121 -121
  42. brainstate/nn/_exp_euler.py +254 -254
  43. brainstate/nn/_exp_euler_test.py +377 -377
  44. brainstate/nn/_linear.py +744 -744
  45. brainstate/nn/_linear_test.py +475 -475
  46. brainstate/nn/_metrics.py +1070 -1070
  47. brainstate/nn/_metrics_test.py +611 -611
  48. brainstate/nn/_module.py +384 -384
  49. brainstate/nn/_module_test.py +40 -40
  50. brainstate/nn/_normalizations.py +1334 -1334
  51. brainstate/nn/_normalizations_test.py +699 -699
  52. brainstate/nn/_paddings.py +1020 -1020
  53. brainstate/nn/_paddings_test.py +722 -722
  54. brainstate/nn/_poolings.py +2239 -2239
  55. brainstate/nn/_poolings_test.py +952 -952
  56. brainstate/nn/_rnns.py +946 -946
  57. brainstate/nn/_rnns_test.py +592 -592
  58. brainstate/nn/_utils.py +216 -216
  59. brainstate/nn/_utils_test.py +401 -401
  60. brainstate/nn/init.py +809 -809
  61. brainstate/nn/init_test.py +180 -180
  62. brainstate/random/__init__.py +270 -270
  63. brainstate/random/_rand_funs.py +3938 -3938
  64. brainstate/random/_rand_funs_test.py +640 -640
  65. brainstate/random/_rand_seed.py +675 -675
  66. brainstate/random/_rand_seed_test.py +48 -48
  67. brainstate/random/_rand_state.py +1617 -1617
  68. brainstate/random/_rand_state_test.py +551 -551
  69. brainstate/transform/__init__.py +59 -59
  70. brainstate/transform/_ad_checkpoint.py +176 -176
  71. brainstate/transform/_ad_checkpoint_test.py +49 -49
  72. brainstate/transform/_autograd.py +1025 -1025
  73. brainstate/transform/_autograd_test.py +1289 -1289
  74. brainstate/transform/_conditions.py +316 -316
  75. brainstate/transform/_conditions_test.py +220 -220
  76. brainstate/transform/_error_if.py +94 -94
  77. brainstate/transform/_error_if_test.py +52 -52
  78. brainstate/transform/_eval_shape.py +145 -145
  79. brainstate/transform/_eval_shape_test.py +38 -38
  80. brainstate/transform/_jit.py +399 -399
  81. brainstate/transform/_jit_test.py +143 -143
  82. brainstate/transform/_loop_collect_return.py +675 -675
  83. brainstate/transform/_loop_collect_return_test.py +58 -58
  84. brainstate/transform/_loop_no_collection.py +283 -283
  85. brainstate/transform/_loop_no_collection_test.py +50 -50
  86. brainstate/transform/_make_jaxpr.py +2016 -2016
  87. brainstate/transform/_make_jaxpr_test.py +1510 -1510
  88. brainstate/transform/_mapping.py +529 -529
  89. brainstate/transform/_mapping_test.py +194 -194
  90. brainstate/transform/_progress_bar.py +255 -255
  91. brainstate/transform/_random.py +171 -171
  92. brainstate/transform/_unvmap.py +256 -256
  93. brainstate/transform/_util.py +286 -286
  94. brainstate/typing.py +837 -837
  95. brainstate/typing_test.py +780 -780
  96. brainstate/util/__init__.py +27 -27
  97. brainstate/util/_others.py +1024 -1024
  98. brainstate/util/_others_test.py +962 -962
  99. brainstate/util/_pretty_pytree.py +1301 -1301
  100. brainstate/util/_pretty_pytree_test.py +675 -675
  101. brainstate/util/_pretty_repr.py +462 -462
  102. brainstate/util/_pretty_repr_test.py +696 -696
  103. brainstate/util/filter.py +945 -945
  104. brainstate/util/filter_test.py +911 -911
  105. brainstate/util/struct.py +910 -910
  106. brainstate/util/struct_test.py +602 -602
  107. {brainstate-0.2.0.dist-info → brainstate-0.2.1.dist-info}/METADATA +108 -108
  108. brainstate-0.2.1.dist-info/RECORD +111 -0
  109. {brainstate-0.2.0.dist-info → brainstate-0.2.1.dist-info}/licenses/LICENSE +202 -202
  110. brainstate-0.2.0.dist-info/RECORD +0 -111
  111. {brainstate-0.2.0.dist-info → brainstate-0.2.1.dist-info}/WHEEL +0 -0
  112. {brainstate-0.2.0.dist-info → brainstate-0.2.1.dist-info}/top_level.txt +0 -0
@@ -1,774 +1,774 @@
1
- # Copyright 2025 BrainX Ecosystem Limited. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
- # -*- coding: utf-8 -*-
17
-
18
-
19
- import jax.numpy as jnp
20
- import pytest
21
-
22
- import brainstate
23
-
24
-
25
- class SimpleTestModule(brainstate.nn.Module):
26
- """Simple test module with init_state method"""
27
-
28
- def __init__(self):
29
- super().__init__()
30
- self.state_initialized = False
31
- self.init_args = None
32
- self.init_kwargs = None
33
-
34
- def init_state(self, *args, **kwargs):
35
- self.state_initialized = True
36
- self.init_args = args
37
- self.init_kwargs = kwargs
38
- self.state = brainstate.State(jnp.zeros(5))
39
-
40
-
41
- class OrderedTestModule(brainstate.nn.Module):
42
- """Test module with call_order decorator"""
43
-
44
- def __init__(self, order_level):
45
- super().__init__()
46
- self.order_level = order_level
47
- self.call_sequence = []
48
-
49
- @brainstate.nn.call_order(0)
50
- def init_state(self):
51
- self.state = brainstate.State(jnp.array([self.order_level]))
52
-
53
-
54
- class NestedModule(brainstate.nn.Module):
55
- """Module with nested submodules"""
56
-
57
- def __init__(self):
58
- super().__init__()
59
- self.submodule1 = SimpleTestModule()
60
- self.submodule2 = SimpleTestModule()
61
-
62
- def init_state(self):
63
- self.state = brainstate.State(jnp.zeros(3))
64
-
65
-
66
- class Test_init_all_states:
67
- """Comprehensive tests for init_all_states function"""
68
-
69
- def test_basic_initialization(self):
70
- """Test basic state initialization"""
71
- module = SimpleTestModule()
72
- assert not module.state_initialized
73
-
74
- brainstate.nn.init_all_states(module)
75
-
76
- assert module.state_initialized
77
- assert module.state.value.shape == (5,)
78
-
79
- def test_with_positional_args(self):
80
- """Test init_all_states with positional arguments"""
81
- module = SimpleTestModule()
82
-
83
- brainstate.nn.init_all_states(module, 1, 2, 3)
84
-
85
- assert module.state_initialized
86
- assert module.init_args == (1, 2, 3)
87
-
88
- def test_with_keyword_args(self):
89
- """Test init_all_states with keyword arguments"""
90
- module = SimpleTestModule()
91
-
92
- brainstate.nn.init_all_states(module, batch_size=10, seq_len=20)
93
-
94
- assert module.state_initialized
95
- assert module.init_kwargs == {'batch_size': 10, 'seq_len': 20}
96
-
97
- def test_with_mixed_args(self):
98
- """Test init_all_states with both positional and keyword arguments"""
99
- module = SimpleTestModule()
100
-
101
- brainstate.nn.init_all_states(module, 42, batch_size=10)
102
-
103
- assert module.state_initialized
104
- assert module.init_args == (42,)
105
- assert module.init_kwargs == {'batch_size': 10}
106
-
107
- def test_nested_modules(self):
108
- """Test that init_all_states initializes nested submodules"""
109
- module = NestedModule()
110
-
111
- brainstate.nn.init_all_states(module)
112
-
113
- assert module.submodule1.state_initialized
114
- assert module.submodule2.state_initialized
115
- assert hasattr(module, 'state')
116
-
117
- def test_with_gru_cell(self):
118
- """Test with real GRUCell module"""
119
- gru = brainstate.nn.GRUCell(1, 2)
120
-
121
- brainstate.nn.init_all_states(gru, batch_size=10)
122
-
123
- # Check that states were created
124
- state_dict = gru.states()
125
- assert len(state_dict) > 0
126
-
127
- def test_sequential_module(self):
128
- """Test with Sequential module containing multiple layers"""
129
- seq = brainstate.nn.Sequential(
130
- brainstate.nn.Linear(10, 20),
131
- brainstate.nn.Dropout(0.5)
132
- )
133
-
134
- brainstate.nn.init_all_states(seq)
135
-
136
- # Check that Linear layer has weight and bias
137
- state_dict = seq.states()
138
- assert len(state_dict) > 0
139
-
140
- def test_node_to_exclude(self):
141
- """Test excluding specific nodes from initialization"""
142
- module = NestedModule()
143
-
144
- # Exclude submodule1 by type matching - simpler and more reliable
145
- brainstate.nn.init_all_states(
146
- module,
147
- node_to_exclude=NestedModule # Exclude the parent, only init children
148
- )
149
-
150
- # Parent should not be initialized, but children should be
151
- assert not hasattr(module, 'state') or module.state is None or not hasattr(module.state, 'value')
152
- assert module.submodule1.state_initialized
153
- assert module.submodule2.state_initialized
154
-
155
- def test_with_call_order(self):
156
- """Test that call_order is respected during initialization"""
157
-
158
- class OrderedModule(brainstate.nn.Module):
159
- def __init__(self):
160
- super().__init__()
161
- self.execution_order = []
162
-
163
- @brainstate.nn.call_order(1)
164
- def init_state(self):
165
- self.execution_order.append('parent')
166
-
167
- class ChildModule(brainstate.nn.Module):
168
- def __init__(self, parent_module):
169
- super().__init__()
170
- self.parent = parent_module
171
-
172
- @brainstate.nn.call_order(0)
173
- def init_state(self):
174
- self.parent.execution_order.append('child')
175
-
176
- parent = OrderedModule()
177
- child = ChildModule(parent)
178
- parent.child_module = child
179
-
180
- brainstate.nn.init_all_states(parent)
181
-
182
- # Child (order 0) should execute before parent (order 1)
183
- assert parent.execution_order == ['child', 'parent']
184
-
185
-
186
- class ResetTestModule(brainstate.nn.Module):
187
- """Test module with both init_state and reset_state methods"""
188
-
189
- def __init__(self):
190
- super().__init__()
191
- self.state_initialized = False
192
- self.state_reset = False
193
- self.reset_args = None
194
- self.reset_kwargs = None
195
-
196
- def init_state(self, *args, **kwargs):
197
- self.state_initialized = True
198
- self.state_reset = False
199
- self.state = brainstate.State(jnp.ones(5))
200
-
201
- def reset_state(self, *args, **kwargs):
202
- self.state_reset = True
203
- self.reset_args = args
204
- self.reset_kwargs = kwargs
205
- if hasattr(self, 'state'):
206
- self.state.value = jnp.zeros(5)
207
-
208
-
209
- class ResetOrderedModule(brainstate.nn.Module):
210
- """Test module with call_order on reset_state"""
211
-
212
- def __init__(self, order_level, execution_tracker):
213
- super().__init__()
214
- self.order_level = order_level
215
- self.execution_tracker = execution_tracker
216
-
217
- def init_state(self):
218
- self.state = brainstate.State(jnp.ones(3))
219
-
220
- @brainstate.nn.call_order(0)
221
- def reset_state(self):
222
- self.execution_tracker.append(f'order_{self.order_level}')
223
- self.state.value = jnp.zeros(3)
224
-
225
-
226
- class NestedResetModule(brainstate.nn.Module):
227
- """Module with nested submodules that have reset_state"""
228
-
229
- def __init__(self):
230
- super().__init__()
231
- self.submodule1 = ResetTestModule()
232
- self.submodule2 = ResetTestModule()
233
-
234
- def init_state(self):
235
- self.state = brainstate.State(jnp.ones(3))
236
-
237
- def reset_state(self):
238
- self.state.value = jnp.zeros(3)
239
-
240
-
241
- class Test_reset_all_states:
242
- """Comprehensive tests for reset_all_states function"""
243
-
244
- def test_basic_reset(self):
245
- """Test basic state reset"""
246
- module = ResetTestModule()
247
- brainstate.nn.init_all_states(module)
248
-
249
- assert module.state_initialized
250
- assert not module.state_reset
251
- assert jnp.allclose(module.state.value, jnp.ones(5))
252
-
253
- brainstate.nn.reset_all_states(module)
254
-
255
- assert module.state_reset
256
- assert jnp.allclose(module.state.value, jnp.zeros(5))
257
-
258
- def test_with_positional_args(self):
259
- """Test reset_all_states with positional arguments"""
260
- module = ResetTestModule()
261
- brainstate.nn.init_all_states(module)
262
-
263
- brainstate.nn.reset_all_states(module, 1, 2, 3)
264
-
265
- assert module.state_reset
266
- assert module.reset_args == (1, 2, 3)
267
-
268
- def test_with_keyword_args(self):
269
- """Test reset_all_states with keyword arguments"""
270
- module = ResetTestModule()
271
- brainstate.nn.init_all_states(module)
272
-
273
- brainstate.nn.reset_all_states(module, batch_size=10, seq_len=20)
274
-
275
- assert module.state_reset
276
- assert module.reset_kwargs == {'batch_size': 10, 'seq_len': 20}
277
-
278
- def test_with_mixed_args(self):
279
- """Test reset_all_states with both positional and keyword arguments"""
280
- module = ResetTestModule()
281
- brainstate.nn.init_all_states(module)
282
-
283
- brainstate.nn.reset_all_states(module, 42, batch_size=10)
284
-
285
- assert module.state_reset
286
- assert module.reset_args == (42,)
287
- assert module.reset_kwargs == {'batch_size': 10}
288
-
289
- def test_nested_modules(self):
290
- """Test that reset_all_states resets nested submodules"""
291
- module = NestedResetModule()
292
- brainstate.nn.init_all_states(module)
293
-
294
- # Verify initial state
295
- assert jnp.allclose(module.state.value, jnp.ones(3))
296
- assert jnp.allclose(module.submodule1.state.value, jnp.ones(5))
297
- assert jnp.allclose(module.submodule2.state.value, jnp.ones(5))
298
-
299
- brainstate.nn.reset_all_states(module)
300
-
301
- # Verify all states were reset
302
- assert jnp.allclose(module.state.value, jnp.zeros(3))
303
- assert module.submodule1.state_reset
304
- assert module.submodule2.state_reset
305
- assert jnp.allclose(module.submodule1.state.value, jnp.zeros(5))
306
- assert jnp.allclose(module.submodule2.state.value, jnp.zeros(5))
307
-
308
- def test_with_gru_cell(self):
309
- """Test reset with real GRUCell module"""
310
- gru = brainstate.nn.GRUCell(5, 10)
311
- brainstate.nn.init_all_states(gru, batch_size=8)
312
-
313
- # Get initial state
314
- initial_states = {k: v.value.copy() for k, v in gru.states().items()
315
- if hasattr(v.value, 'copy') and not isinstance(v.value, dict)}
316
-
317
- # Reset state
318
- brainstate.nn.reset_all_states(gru, batch_size=8)
319
-
320
- # Verify state was reset (should be zeros for hidden state)
321
- for key in initial_states:
322
- current_val = gru.states()[key].value
323
- if not isinstance(current_val, dict):
324
- # Hidden state should be reset to zeros
325
- if 'h' in str(key):
326
- assert jnp.allclose(current_val, jnp.zeros_like(current_val))
327
-
328
- def test_sequential_reset(self):
329
- """Test reset with Sequential module"""
330
-
331
- # Create a simple network with stateful components
332
- class StatefulLayer(brainstate.nn.Module):
333
- def __init__(self):
334
- super().__init__()
335
- self.reset_called = False
336
-
337
- def init_state(self):
338
- self.state = brainstate.State(jnp.ones(5))
339
-
340
- def reset_state(self):
341
- self.reset_called = True
342
- self.state.value = jnp.zeros(5)
343
-
344
- layer1 = StatefulLayer()
345
- layer2 = StatefulLayer()
346
- seq = brainstate.nn.Sequential(layer1, layer2)
347
-
348
- brainstate.nn.init_all_states(seq)
349
- brainstate.nn.reset_all_states(seq)
350
-
351
- assert layer1.reset_called
352
- assert layer2.reset_called
353
-
354
- def test_node_to_exclude(self):
355
- """Test excluding specific nodes from reset"""
356
- module = NestedResetModule()
357
- brainstate.nn.init_all_states(module)
358
-
359
- # Exclude the parent module from reset
360
- brainstate.nn.reset_all_states(
361
- module,
362
- node_to_exclude=NestedResetModule
363
- )
364
-
365
- # Parent should not be reset, but children should be
366
- assert jnp.allclose(module.state.value, jnp.ones(3)) # Not reset
367
- assert module.submodule1.state_reset # Reset
368
- assert module.submodule2.state_reset # Reset
369
-
370
- def test_with_call_order(self):
371
- """Test that call_order is respected during reset"""
372
- execution_tracker = []
373
-
374
- class OrderedParent(brainstate.nn.Module):
375
- def __init__(self):
376
- super().__init__()
377
- self.child1 = ResetOrderedModule(1, execution_tracker)
378
- self.child2 = ResetOrderedModule(2, execution_tracker)
379
-
380
- def init_state(self):
381
- pass
382
-
383
- parent = OrderedParent()
384
- brainstate.nn.init_all_states(parent)
385
-
386
- execution_tracker.clear()
387
- brainstate.nn.reset_all_states(parent)
388
-
389
- # Both should execute (order 0), check that reset was called
390
- assert len(execution_tracker) == 2
391
-
392
- def test_multiple_resets(self):
393
- """Test calling reset_all_states multiple times"""
394
- module = ResetTestModule()
395
- brainstate.nn.init_all_states(module)
396
-
397
- for i in range(3):
398
- brainstate.nn.reset_all_states(module)
399
- assert module.state_reset
400
- assert jnp.allclose(module.state.value, jnp.zeros(5))
401
-
402
- def test_reset_without_init(self):
403
- """Test that reset works even if init wasn't called explicitly"""
404
- gru = brainstate.nn.GRUCell(5, 10)
405
-
406
- # Initialize first
407
- brainstate.nn.init_all_states(gru, batch_size=8)
408
-
409
- # Reset should work
410
- brainstate.nn.reset_all_states(gru, batch_size=8)
411
-
412
- # Verify it didn't crash
413
- states = gru.states()
414
- assert len(states) > 0
415
-
416
-
417
- class CustomMethodModule(brainstate.nn.Module):
418
- """Test module with custom methods for call_all_fns testing"""
419
-
420
- def __init__(self):
421
- super().__init__()
422
- self.method_called = False
423
- self.call_count = 0
424
- self.received_args = None
425
- self.received_kwargs = None
426
-
427
- def custom_method(self, *args, **kwargs):
428
- self.method_called = True
429
- self.call_count += 1
430
- self.received_args = args
431
- self.received_kwargs = kwargs
432
-
433
- def another_method(self):
434
- self.call_count += 10
435
-
436
-
437
- class OrderedCallModule(brainstate.nn.Module):
438
- """Test module with ordered methods"""
439
-
440
- def __init__(self, execution_tracker):
441
- super().__init__()
442
- self.execution_tracker = execution_tracker
443
-
444
- @brainstate.nn.call_order(0)
445
- def ordered_method_0(self):
446
- self.execution_tracker.append('order_0')
447
-
448
- @brainstate.nn.call_order(1)
449
- def ordered_method_1(self):
450
- self.execution_tracker.append('order_1')
451
-
452
- @brainstate.nn.call_order(2)
453
- def ordered_method_2(self):
454
- self.execution_tracker.append('order_2')
455
-
456
- def unordered_method(self):
457
- self.execution_tracker.append('unordered')
458
-
459
-
460
- class NestedCallModule(brainstate.nn.Module):
461
- """Module with nested submodules for call testing"""
462
-
463
- def __init__(self):
464
- super().__init__()
465
- self.child1 = CustomMethodModule()
466
- self.child2 = CustomMethodModule()
467
- self.method_called = False
468
-
469
- def custom_method(self, *args, **kwargs):
470
- self.method_called = True
471
-
472
-
473
- class Test_call_order:
474
- """Comprehensive tests for call_order decorator"""
475
-
476
- def test_basic_call_order(self):
477
- """Test basic call_order decorator"""
478
- execution_tracker = []
479
-
480
- class TestModule(brainstate.nn.Module):
481
- @brainstate.nn.call_order(0)
482
- def method(self):
483
- execution_tracker.append('executed')
484
-
485
- module = TestModule()
486
- assert hasattr(module.method, 'call_order')
487
- assert module.method.call_order == 0
488
-
489
- def test_different_order_levels(self):
490
- """Test different order levels"""
491
- for level in [0, 1, 5, 9]:
492
- class TestModule(brainstate.nn.Module):
493
- @brainstate.nn.call_order(level)
494
- def method(self):
495
- pass
496
-
497
- module = TestModule()
498
- assert module.method.call_order == level
499
-
500
- def test_order_boundary_validation(self):
501
- """Test that order level boundary validation works"""
502
- # Valid levels (0 to MAX_ORDER-1)
503
- for level in range(brainstate.nn._collective_ops.MAX_ORDER):
504
- @brainstate.nn.call_order(level)
505
- def valid_method():
506
- pass
507
-
508
- assert valid_method.call_order == level
509
-
510
- # Invalid levels
511
- with pytest.raises(ValueError, match="must be an integer"):
512
- @brainstate.nn.call_order(-1)
513
- def invalid_method1():
514
- pass
515
-
516
- with pytest.raises(ValueError, match="must be an integer"):
517
- @brainstate.nn.call_order(brainstate.nn._collective_ops.MAX_ORDER)
518
- def invalid_method2():
519
- pass
520
-
521
- def test_disable_boundary_check(self):
522
- """Test disabling boundary check"""
523
-
524
- @brainstate.nn.call_order(100, check_order_boundary=False)
525
- def method():
526
- pass
527
-
528
- assert method.call_order == 100
529
-
530
- @brainstate.nn.call_order(-5, check_order_boundary=False)
531
- def method2():
532
- pass
533
-
534
- assert method2.call_order == -5
535
-
536
- def test_order_preserved_on_methods(self):
537
- """Test that call_order is preserved on instance methods"""
538
- execution_tracker = []
539
- module = OrderedCallModule(execution_tracker)
540
-
541
- assert module.ordered_method_0.call_order == 0
542
- assert module.ordered_method_1.call_order == 1
543
- assert module.ordered_method_2.call_order == 2
544
- assert not hasattr(module.unordered_method, 'call_order')
545
-
546
- def test_multiple_decorators(self):
547
- """Test applying call_order to multiple methods"""
548
- execution_tracker = []
549
-
550
- class MultiMethodModule(brainstate.nn.Module):
551
- @brainstate.nn.call_order(2)
552
- def method_a(self):
553
- execution_tracker.append('a')
554
-
555
- @brainstate.nn.call_order(0)
556
- def method_b(self):
557
- execution_tracker.append('b')
558
-
559
- @brainstate.nn.call_order(1)
560
- def method_c(self):
561
- execution_tracker.append('c')
562
-
563
- module = MultiMethodModule()
564
- assert module.method_a.call_order == 2
565
- assert module.method_b.call_order == 0
566
- assert module.method_c.call_order == 1
567
-
568
-
569
- class Test_call_all_fns:
570
- """Comprehensive tests for call_all_fns function"""
571
-
572
- def test_basic_function_call(self):
573
- """Test basic function calling"""
574
- module = CustomMethodModule()
575
-
576
- assert not module.method_called
577
- brainstate.nn.call_all_fns(module, 'custom_method')
578
- assert module.method_called
579
- assert module.call_count == 1
580
-
581
- def test_with_positional_args(self):
582
- """Test call_all_fns with positional arguments"""
583
- module = CustomMethodModule()
584
-
585
- brainstate.nn.call_all_fns(module, 'custom_method', (1, 2, 3))
586
-
587
- assert module.method_called
588
- assert module.received_args == (1, 2, 3)
589
-
590
- def test_with_keyword_args(self):
591
- """Test call_all_fns with keyword arguments"""
592
- module = CustomMethodModule()
593
-
594
- brainstate.nn.call_all_fns(module, 'custom_method', kwargs={'foo': 'bar', 'baz': 42})
595
-
596
- assert module.method_called
597
- assert module.received_kwargs == {'foo': 'bar', 'baz': 42}
598
-
599
- def test_with_mixed_args(self):
600
- """Test call_all_fns with both positional and keyword arguments"""
601
- module = CustomMethodModule()
602
-
603
- brainstate.nn.call_all_fns(
604
- module,
605
- 'custom_method',
606
- args=(1, 2),
607
- kwargs={'key': 'value'}
608
- )
609
-
610
- assert module.method_called
611
- assert module.received_args == (1, 2)
612
- assert module.received_kwargs == {'key': 'value'}
613
-
614
- def test_nested_modules(self):
615
- """Test that call_all_fns calls methods on nested modules"""
616
- module = NestedCallModule()
617
-
618
- brainstate.nn.call_all_fns(module, 'custom_method')
619
-
620
- assert module.method_called
621
- assert module.child1.method_called
622
- assert module.child2.method_called
623
-
624
- def test_call_order_respected(self):
625
- """Test that call_order is respected"""
626
- execution_tracker = []
627
-
628
- class ParentModule(brainstate.nn.Module):
629
- def __init__(self):
630
- super().__init__()
631
- self.child1 = OrderedCallModule(execution_tracker)
632
- self.child2 = OrderedCallModule(execution_tracker)
633
-
634
- module = ParentModule()
635
-
636
- # Call ordered_method_1 on all modules (parent doesn't have it, so skip)
637
- brainstate.nn.call_all_fns(module, 'ordered_method_1', fn_if_not_exist='pass')
638
-
639
- # Should be called on both children (both have order 1)
640
- assert execution_tracker.count('order_1') == 2
641
-
642
- def test_execution_order_with_mixed_decorators(self):
643
- """Test execution order with both ordered and unordered methods"""
644
- execution_tracker = []
645
-
646
- class MixedModule(brainstate.nn.Module):
647
- def __init__(self):
648
- super().__init__()
649
- self.ordered = OrderedCallModule(execution_tracker)
650
-
651
- def unordered_method(self):
652
- execution_tracker.append('parent_unordered')
653
-
654
- module = MixedModule()
655
-
656
- # Call unordered_method - parent has no decorator, child has no decorator
657
- execution_tracker.clear()
658
- brainstate.nn.call_all_fns(module, 'unordered_method')
659
-
660
- # Both should be called (unordered methods execute first)
661
- assert 'parent_unordered' in execution_tracker
662
- assert 'unordered' in execution_tracker
663
-
664
- def test_node_to_exclude(self):
665
- """Test excluding specific nodes"""
666
- module = NestedCallModule()
667
-
668
- # Exclude the parent module
669
- brainstate.nn.call_all_fns(
670
- module,
671
- 'custom_method',
672
- node_to_exclude=NestedCallModule
673
- )
674
-
675
- # Parent should not be called, but children should be
676
- assert not module.method_called
677
- assert module.child1.method_called
678
- assert module.child2.method_called
679
-
680
- def test_fn_if_not_exist_raise(self):
681
- """Test fn_if_not_exist='raise' behavior"""
682
- module = CustomMethodModule()
683
-
684
- with pytest.raises(AttributeError, match="does not have method"):
685
- brainstate.nn.call_all_fns(module, 'nonexistent_method', fn_if_not_exist='raise')
686
-
687
- def test_fn_if_not_exist_pass(self):
688
- """Test fn_if_not_exist='pass' behavior"""
689
- module = CustomMethodModule()
690
-
691
- # Should not raise error
692
- brainstate.nn.call_all_fns(module, 'nonexistent_method', fn_if_not_exist='pass')
693
-
694
- def test_fn_if_not_exist_none(self):
695
- """Test fn_if_not_exist='none' behavior"""
696
- module = CustomMethodModule()
697
-
698
- # Should not raise error
699
- brainstate.nn.call_all_fns(module, 'nonexistent_method', fn_if_not_exist='none')
700
-
701
- def test_fn_if_not_exist_warn(self):
702
- """Test fn_if_not_exist='warn' behavior"""
703
- module = CustomMethodModule()
704
-
705
- # Should issue warning but not raise
706
- with pytest.warns(UserWarning, match="does not have method"):
707
- brainstate.nn.call_all_fns(module, 'nonexistent_method', fn_if_not_exist='warn')
708
-
709
- def test_invalid_fn_name_type(self):
710
- """Test that invalid fn_name type raises error"""
711
- module = CustomMethodModule()
712
-
713
- with pytest.raises(TypeError, match="fn_name must be a string"):
714
- brainstate.nn.call_all_fns(module, 123)
715
-
716
- def test_invalid_kwargs_type(self):
717
- """Test that invalid kwargs type raises error"""
718
- module = CustomMethodModule()
719
-
720
- with pytest.raises(TypeError, match="kwargs must be a mapping"):
721
- brainstate.nn.call_all_fns(module, 'custom_method', kwargs=[1, 2, 3])
722
-
723
- def test_non_callable_attribute(self):
724
- """Test that non-callable attributes raise error"""
725
-
726
- class ModuleWithAttribute(brainstate.nn.Module):
727
- def __init__(self):
728
- super().__init__()
729
- self.my_attr = "not callable"
730
-
731
- module = ModuleWithAttribute()
732
-
733
- with pytest.raises(TypeError, match="must be callable"):
734
- brainstate.nn.call_all_fns(module, 'my_attr')
735
-
736
- def test_with_gru_cell(self):
737
- """Test with real GRU cell"""
738
- gru = brainstate.nn.GRUCell(5, 10)
739
-
740
- # Initialize states
741
- brainstate.nn.call_all_fns(gru, 'init_state', kwargs={'batch_size': 8})
742
-
743
- # Verify states were created
744
- states = gru.states()
745
- assert len(states) > 0
746
-
747
- def test_multiple_calls_same_function(self):
748
- """Test calling same function multiple times"""
749
- module = CustomMethodModule()
750
-
751
- for i in range(5):
752
- brainstate.nn.call_all_fns(module, 'custom_method')
753
-
754
- assert module.call_count == 5
755
-
756
- def test_single_non_tuple_arg(self):
757
- """Test that single non-tuple argument is wrapped"""
758
- module = CustomMethodModule()
759
-
760
- brainstate.nn.call_all_fns(module, 'custom_method', args=42)
761
-
762
- assert module.received_args == (42,)
763
-
764
- def test_sequential_module(self):
765
- """Test with Sequential module"""
766
- layer1 = CustomMethodModule()
767
- layer2 = CustomMethodModule()
768
- seq = brainstate.nn.Sequential(layer1, layer2)
769
-
770
- # Sequential itself doesn't have custom_method, so skip it
771
- brainstate.nn.call_all_fns(seq, 'custom_method', fn_if_not_exist='pass')
772
-
773
- assert layer1.method_called
774
- assert layer2.method_called
1
+ # Copyright 2025 BrainX Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ # -*- coding: utf-8 -*-
17
+
18
+
19
+ import jax.numpy as jnp
20
+ import pytest
21
+
22
+ import brainstate
23
+
24
+
25
+ class SimpleTestModule(brainstate.nn.Module):
26
+ """Simple test module with init_state method"""
27
+
28
+ def __init__(self):
29
+ super().__init__()
30
+ self.state_initialized = False
31
+ self.init_args = None
32
+ self.init_kwargs = None
33
+
34
+ def init_state(self, *args, **kwargs):
35
+ self.state_initialized = True
36
+ self.init_args = args
37
+ self.init_kwargs = kwargs
38
+ self.state = brainstate.State(jnp.zeros(5))
39
+
40
+
41
+ class OrderedTestModule(brainstate.nn.Module):
42
+ """Test module with call_order decorator"""
43
+
44
+ def __init__(self, order_level):
45
+ super().__init__()
46
+ self.order_level = order_level
47
+ self.call_sequence = []
48
+
49
+ @brainstate.nn.call_order(0)
50
+ def init_state(self):
51
+ self.state = brainstate.State(jnp.array([self.order_level]))
52
+
53
+
54
+ class NestedModule(brainstate.nn.Module):
55
+ """Module with nested submodules"""
56
+
57
+ def __init__(self):
58
+ super().__init__()
59
+ self.submodule1 = SimpleTestModule()
60
+ self.submodule2 = SimpleTestModule()
61
+
62
+ def init_state(self):
63
+ self.state = brainstate.State(jnp.zeros(3))
64
+
65
+
66
+ class Test_init_all_states:
67
+ """Comprehensive tests for init_all_states function"""
68
+
69
+ def test_basic_initialization(self):
70
+ """Test basic state initialization"""
71
+ module = SimpleTestModule()
72
+ assert not module.state_initialized
73
+
74
+ brainstate.nn.init_all_states(module)
75
+
76
+ assert module.state_initialized
77
+ assert module.state.value.shape == (5,)
78
+
79
+ def test_with_positional_args(self):
80
+ """Test init_all_states with positional arguments"""
81
+ module = SimpleTestModule()
82
+
83
+ brainstate.nn.init_all_states(module, 1, 2, 3)
84
+
85
+ assert module.state_initialized
86
+ assert module.init_args == (1, 2, 3)
87
+
88
+ def test_with_keyword_args(self):
89
+ """Test init_all_states with keyword arguments"""
90
+ module = SimpleTestModule()
91
+
92
+ brainstate.nn.init_all_states(module, batch_size=10, seq_len=20)
93
+
94
+ assert module.state_initialized
95
+ assert module.init_kwargs == {'batch_size': 10, 'seq_len': 20}
96
+
97
+ def test_with_mixed_args(self):
98
+ """Test init_all_states with both positional and keyword arguments"""
99
+ module = SimpleTestModule()
100
+
101
+ brainstate.nn.init_all_states(module, 42, batch_size=10)
102
+
103
+ assert module.state_initialized
104
+ assert module.init_args == (42,)
105
+ assert module.init_kwargs == {'batch_size': 10}
106
+
107
+ def test_nested_modules(self):
108
+ """Test that init_all_states initializes nested submodules"""
109
+ module = NestedModule()
110
+
111
+ brainstate.nn.init_all_states(module)
112
+
113
+ assert module.submodule1.state_initialized
114
+ assert module.submodule2.state_initialized
115
+ assert hasattr(module, 'state')
116
+
117
+ def test_with_gru_cell(self):
118
+ """Test with real GRUCell module"""
119
+ gru = brainstate.nn.GRUCell(1, 2)
120
+
121
+ brainstate.nn.init_all_states(gru, batch_size=10)
122
+
123
+ # Check that states were created
124
+ state_dict = gru.states()
125
+ assert len(state_dict) > 0
126
+
127
+ def test_sequential_module(self):
128
+ """Test with Sequential module containing multiple layers"""
129
+ seq = brainstate.nn.Sequential(
130
+ brainstate.nn.Linear(10, 20),
131
+ brainstate.nn.Dropout(0.5)
132
+ )
133
+
134
+ brainstate.nn.init_all_states(seq)
135
+
136
+ # Check that Linear layer has weight and bias
137
+ state_dict = seq.states()
138
+ assert len(state_dict) > 0
139
+
140
+ def test_node_to_exclude(self):
141
+ """Test excluding specific nodes from initialization"""
142
+ module = NestedModule()
143
+
144
+ # Exclude submodule1 by type matching - simpler and more reliable
145
+ brainstate.nn.init_all_states(
146
+ module,
147
+ node_to_exclude=NestedModule # Exclude the parent, only init children
148
+ )
149
+
150
+ # Parent should not be initialized, but children should be
151
+ assert not hasattr(module, 'state') or module.state is None or not hasattr(module.state, 'value')
152
+ assert module.submodule1.state_initialized
153
+ assert module.submodule2.state_initialized
154
+
155
+ def test_with_call_order(self):
156
+ """Test that call_order is respected during initialization"""
157
+
158
+ class OrderedModule(brainstate.nn.Module):
159
+ def __init__(self):
160
+ super().__init__()
161
+ self.execution_order = []
162
+
163
+ @brainstate.nn.call_order(1)
164
+ def init_state(self):
165
+ self.execution_order.append('parent')
166
+
167
+ class ChildModule(brainstate.nn.Module):
168
+ def __init__(self, parent_module):
169
+ super().__init__()
170
+ self.parent = parent_module
171
+
172
+ @brainstate.nn.call_order(0)
173
+ def init_state(self):
174
+ self.parent.execution_order.append('child')
175
+
176
+ parent = OrderedModule()
177
+ child = ChildModule(parent)
178
+ parent.child_module = child
179
+
180
+ brainstate.nn.init_all_states(parent)
181
+
182
+ # Child (order 0) should execute before parent (order 1)
183
+ assert parent.execution_order == ['child', 'parent']
184
+
185
+
186
+ class ResetTestModule(brainstate.nn.Module):
187
+ """Test module with both init_state and reset_state methods"""
188
+
189
+ def __init__(self):
190
+ super().__init__()
191
+ self.state_initialized = False
192
+ self.state_reset = False
193
+ self.reset_args = None
194
+ self.reset_kwargs = None
195
+
196
+ def init_state(self, *args, **kwargs):
197
+ self.state_initialized = True
198
+ self.state_reset = False
199
+ self.state = brainstate.State(jnp.ones(5))
200
+
201
+ def reset_state(self, *args, **kwargs):
202
+ self.state_reset = True
203
+ self.reset_args = args
204
+ self.reset_kwargs = kwargs
205
+ if hasattr(self, 'state'):
206
+ self.state.value = jnp.zeros(5)
207
+
208
+
209
+ class ResetOrderedModule(brainstate.nn.Module):
210
+ """Test module with call_order on reset_state"""
211
+
212
+ def __init__(self, order_level, execution_tracker):
213
+ super().__init__()
214
+ self.order_level = order_level
215
+ self.execution_tracker = execution_tracker
216
+
217
+ def init_state(self):
218
+ self.state = brainstate.State(jnp.ones(3))
219
+
220
+ @brainstate.nn.call_order(0)
221
+ def reset_state(self):
222
+ self.execution_tracker.append(f'order_{self.order_level}')
223
+ self.state.value = jnp.zeros(3)
224
+
225
+
226
+ class NestedResetModule(brainstate.nn.Module):
227
+ """Module with nested submodules that have reset_state"""
228
+
229
+ def __init__(self):
230
+ super().__init__()
231
+ self.submodule1 = ResetTestModule()
232
+ self.submodule2 = ResetTestModule()
233
+
234
+ def init_state(self):
235
+ self.state = brainstate.State(jnp.ones(3))
236
+
237
+ def reset_state(self):
238
+ self.state.value = jnp.zeros(3)
239
+
240
+
241
+ class Test_reset_all_states:
242
+ """Comprehensive tests for reset_all_states function"""
243
+
244
+ def test_basic_reset(self):
245
+ """Test basic state reset"""
246
+ module = ResetTestModule()
247
+ brainstate.nn.init_all_states(module)
248
+
249
+ assert module.state_initialized
250
+ assert not module.state_reset
251
+ assert jnp.allclose(module.state.value, jnp.ones(5))
252
+
253
+ brainstate.nn.reset_all_states(module)
254
+
255
+ assert module.state_reset
256
+ assert jnp.allclose(module.state.value, jnp.zeros(5))
257
+
258
+ def test_with_positional_args(self):
259
+ """Test reset_all_states with positional arguments"""
260
+ module = ResetTestModule()
261
+ brainstate.nn.init_all_states(module)
262
+
263
+ brainstate.nn.reset_all_states(module, 1, 2, 3)
264
+
265
+ assert module.state_reset
266
+ assert module.reset_args == (1, 2, 3)
267
+
268
+ def test_with_keyword_args(self):
269
+ """Test reset_all_states with keyword arguments"""
270
+ module = ResetTestModule()
271
+ brainstate.nn.init_all_states(module)
272
+
273
+ brainstate.nn.reset_all_states(module, batch_size=10, seq_len=20)
274
+
275
+ assert module.state_reset
276
+ assert module.reset_kwargs == {'batch_size': 10, 'seq_len': 20}
277
+
278
+ def test_with_mixed_args(self):
279
+ """Test reset_all_states with both positional and keyword arguments"""
280
+ module = ResetTestModule()
281
+ brainstate.nn.init_all_states(module)
282
+
283
+ brainstate.nn.reset_all_states(module, 42, batch_size=10)
284
+
285
+ assert module.state_reset
286
+ assert module.reset_args == (42,)
287
+ assert module.reset_kwargs == {'batch_size': 10}
288
+
289
+ def test_nested_modules(self):
290
+ """Test that reset_all_states resets nested submodules"""
291
+ module = NestedResetModule()
292
+ brainstate.nn.init_all_states(module)
293
+
294
+ # Verify initial state
295
+ assert jnp.allclose(module.state.value, jnp.ones(3))
296
+ assert jnp.allclose(module.submodule1.state.value, jnp.ones(5))
297
+ assert jnp.allclose(module.submodule2.state.value, jnp.ones(5))
298
+
299
+ brainstate.nn.reset_all_states(module)
300
+
301
+ # Verify all states were reset
302
+ assert jnp.allclose(module.state.value, jnp.zeros(3))
303
+ assert module.submodule1.state_reset
304
+ assert module.submodule2.state_reset
305
+ assert jnp.allclose(module.submodule1.state.value, jnp.zeros(5))
306
+ assert jnp.allclose(module.submodule2.state.value, jnp.zeros(5))
307
+
308
+ def test_with_gru_cell(self):
309
+ """Test reset with real GRUCell module"""
310
+ gru = brainstate.nn.GRUCell(5, 10)
311
+ brainstate.nn.init_all_states(gru, batch_size=8)
312
+
313
+ # Get initial state
314
+ initial_states = {k: v.value.copy() for k, v in gru.states().items()
315
+ if hasattr(v.value, 'copy') and not isinstance(v.value, dict)}
316
+
317
+ # Reset state
318
+ brainstate.nn.reset_all_states(gru, batch_size=8)
319
+
320
+ # Verify state was reset (should be zeros for hidden state)
321
+ for key in initial_states:
322
+ current_val = gru.states()[key].value
323
+ if not isinstance(current_val, dict):
324
+ # Hidden state should be reset to zeros
325
+ if 'h' in str(key):
326
+ assert jnp.allclose(current_val, jnp.zeros_like(current_val))
327
+
328
+ def test_sequential_reset(self):
329
+ """Test reset with Sequential module"""
330
+
331
+ # Create a simple network with stateful components
332
+ class StatefulLayer(brainstate.nn.Module):
333
+ def __init__(self):
334
+ super().__init__()
335
+ self.reset_called = False
336
+
337
+ def init_state(self):
338
+ self.state = brainstate.State(jnp.ones(5))
339
+
340
+ def reset_state(self):
341
+ self.reset_called = True
342
+ self.state.value = jnp.zeros(5)
343
+
344
+ layer1 = StatefulLayer()
345
+ layer2 = StatefulLayer()
346
+ seq = brainstate.nn.Sequential(layer1, layer2)
347
+
348
+ brainstate.nn.init_all_states(seq)
349
+ brainstate.nn.reset_all_states(seq)
350
+
351
+ assert layer1.reset_called
352
+ assert layer2.reset_called
353
+
354
+ def test_node_to_exclude(self):
355
+ """Test excluding specific nodes from reset"""
356
+ module = NestedResetModule()
357
+ brainstate.nn.init_all_states(module)
358
+
359
+ # Exclude the parent module from reset
360
+ brainstate.nn.reset_all_states(
361
+ module,
362
+ node_to_exclude=NestedResetModule
363
+ )
364
+
365
+ # Parent should not be reset, but children should be
366
+ assert jnp.allclose(module.state.value, jnp.ones(3)) # Not reset
367
+ assert module.submodule1.state_reset # Reset
368
+ assert module.submodule2.state_reset # Reset
369
+
370
+ def test_with_call_order(self):
371
+ """Test that call_order is respected during reset"""
372
+ execution_tracker = []
373
+
374
+ class OrderedParent(brainstate.nn.Module):
375
+ def __init__(self):
376
+ super().__init__()
377
+ self.child1 = ResetOrderedModule(1, execution_tracker)
378
+ self.child2 = ResetOrderedModule(2, execution_tracker)
379
+
380
+ def init_state(self):
381
+ pass
382
+
383
+ parent = OrderedParent()
384
+ brainstate.nn.init_all_states(parent)
385
+
386
+ execution_tracker.clear()
387
+ brainstate.nn.reset_all_states(parent)
388
+
389
+ # Both should execute (order 0), check that reset was called
390
+ assert len(execution_tracker) == 2
391
+
392
+ def test_multiple_resets(self):
393
+ """Test calling reset_all_states multiple times"""
394
+ module = ResetTestModule()
395
+ brainstate.nn.init_all_states(module)
396
+
397
+ for i in range(3):
398
+ brainstate.nn.reset_all_states(module)
399
+ assert module.state_reset
400
+ assert jnp.allclose(module.state.value, jnp.zeros(5))
401
+
402
+ def test_reset_without_init(self):
403
+ """Test that reset works even if init wasn't called explicitly"""
404
+ gru = brainstate.nn.GRUCell(5, 10)
405
+
406
+ # Initialize first
407
+ brainstate.nn.init_all_states(gru, batch_size=8)
408
+
409
+ # Reset should work
410
+ brainstate.nn.reset_all_states(gru, batch_size=8)
411
+
412
+ # Verify it didn't crash
413
+ states = gru.states()
414
+ assert len(states) > 0
415
+
416
+
417
+ class CustomMethodModule(brainstate.nn.Module):
418
+ """Test module with custom methods for call_all_fns testing"""
419
+
420
+ def __init__(self):
421
+ super().__init__()
422
+ self.method_called = False
423
+ self.call_count = 0
424
+ self.received_args = None
425
+ self.received_kwargs = None
426
+
427
+ def custom_method(self, *args, **kwargs):
428
+ self.method_called = True
429
+ self.call_count += 1
430
+ self.received_args = args
431
+ self.received_kwargs = kwargs
432
+
433
+ def another_method(self):
434
+ self.call_count += 10
435
+
436
+
437
+ class OrderedCallModule(brainstate.nn.Module):
438
+ """Test module with ordered methods"""
439
+
440
+ def __init__(self, execution_tracker):
441
+ super().__init__()
442
+ self.execution_tracker = execution_tracker
443
+
444
+ @brainstate.nn.call_order(0)
445
+ def ordered_method_0(self):
446
+ self.execution_tracker.append('order_0')
447
+
448
+ @brainstate.nn.call_order(1)
449
+ def ordered_method_1(self):
450
+ self.execution_tracker.append('order_1')
451
+
452
+ @brainstate.nn.call_order(2)
453
+ def ordered_method_2(self):
454
+ self.execution_tracker.append('order_2')
455
+
456
+ def unordered_method(self):
457
+ self.execution_tracker.append('unordered')
458
+
459
+
460
+ class NestedCallModule(brainstate.nn.Module):
461
+ """Module with nested submodules for call testing"""
462
+
463
+ def __init__(self):
464
+ super().__init__()
465
+ self.child1 = CustomMethodModule()
466
+ self.child2 = CustomMethodModule()
467
+ self.method_called = False
468
+
469
+ def custom_method(self, *args, **kwargs):
470
+ self.method_called = True
471
+
472
+
473
+ class Test_call_order:
474
+ """Comprehensive tests for call_order decorator"""
475
+
476
+ def test_basic_call_order(self):
477
+ """Test basic call_order decorator"""
478
+ execution_tracker = []
479
+
480
+ class TestModule(brainstate.nn.Module):
481
+ @brainstate.nn.call_order(0)
482
+ def method(self):
483
+ execution_tracker.append('executed')
484
+
485
+ module = TestModule()
486
+ assert hasattr(module.method, 'call_order')
487
+ assert module.method.call_order == 0
488
+
489
+ def test_different_order_levels(self):
490
+ """Test different order levels"""
491
+ for level in [0, 1, 5, 9]:
492
+ class TestModule(brainstate.nn.Module):
493
+ @brainstate.nn.call_order(level)
494
+ def method(self):
495
+ pass
496
+
497
+ module = TestModule()
498
+ assert module.method.call_order == level
499
+
500
+ def test_order_boundary_validation(self):
501
+ """Test that order level boundary validation works"""
502
+ # Valid levels (0 to MAX_ORDER-1)
503
+ for level in range(brainstate.nn._collective_ops.MAX_ORDER):
504
+ @brainstate.nn.call_order(level)
505
+ def valid_method():
506
+ pass
507
+
508
+ assert valid_method.call_order == level
509
+
510
+ # Invalid levels
511
+ with pytest.raises(ValueError, match="must be an integer"):
512
+ @brainstate.nn.call_order(-1)
513
+ def invalid_method1():
514
+ pass
515
+
516
+ with pytest.raises(ValueError, match="must be an integer"):
517
+ @brainstate.nn.call_order(brainstate.nn._collective_ops.MAX_ORDER)
518
+ def invalid_method2():
519
+ pass
520
+
521
+ def test_disable_boundary_check(self):
522
+ """Test disabling boundary check"""
523
+
524
+ @brainstate.nn.call_order(100, check_order_boundary=False)
525
+ def method():
526
+ pass
527
+
528
+ assert method.call_order == 100
529
+
530
+ @brainstate.nn.call_order(-5, check_order_boundary=False)
531
+ def method2():
532
+ pass
533
+
534
+ assert method2.call_order == -5
535
+
536
+ def test_order_preserved_on_methods(self):
537
+ """Test that call_order is preserved on instance methods"""
538
+ execution_tracker = []
539
+ module = OrderedCallModule(execution_tracker)
540
+
541
+ assert module.ordered_method_0.call_order == 0
542
+ assert module.ordered_method_1.call_order == 1
543
+ assert module.ordered_method_2.call_order == 2
544
+ assert not hasattr(module.unordered_method, 'call_order')
545
+
546
+ def test_multiple_decorators(self):
547
+ """Test applying call_order to multiple methods"""
548
+ execution_tracker = []
549
+
550
+ class MultiMethodModule(brainstate.nn.Module):
551
+ @brainstate.nn.call_order(2)
552
+ def method_a(self):
553
+ execution_tracker.append('a')
554
+
555
+ @brainstate.nn.call_order(0)
556
+ def method_b(self):
557
+ execution_tracker.append('b')
558
+
559
+ @brainstate.nn.call_order(1)
560
+ def method_c(self):
561
+ execution_tracker.append('c')
562
+
563
+ module = MultiMethodModule()
564
+ assert module.method_a.call_order == 2
565
+ assert module.method_b.call_order == 0
566
+ assert module.method_c.call_order == 1
567
+
568
+
569
+ class Test_call_all_fns:
570
+ """Comprehensive tests for call_all_fns function"""
571
+
572
+ def test_basic_function_call(self):
573
+ """Test basic function calling"""
574
+ module = CustomMethodModule()
575
+
576
+ assert not module.method_called
577
+ brainstate.nn.call_all_fns(module, 'custom_method')
578
+ assert module.method_called
579
+ assert module.call_count == 1
580
+
581
+ def test_with_positional_args(self):
582
+ """Test call_all_fns with positional arguments"""
583
+ module = CustomMethodModule()
584
+
585
+ brainstate.nn.call_all_fns(module, 'custom_method', (1, 2, 3))
586
+
587
+ assert module.method_called
588
+ assert module.received_args == (1, 2, 3)
589
+
590
+ def test_with_keyword_args(self):
591
+ """Test call_all_fns with keyword arguments"""
592
+ module = CustomMethodModule()
593
+
594
+ brainstate.nn.call_all_fns(module, 'custom_method', kwargs={'foo': 'bar', 'baz': 42})
595
+
596
+ assert module.method_called
597
+ assert module.received_kwargs == {'foo': 'bar', 'baz': 42}
598
+
599
+ def test_with_mixed_args(self):
600
+ """Test call_all_fns with both positional and keyword arguments"""
601
+ module = CustomMethodModule()
602
+
603
+ brainstate.nn.call_all_fns(
604
+ module,
605
+ 'custom_method',
606
+ args=(1, 2),
607
+ kwargs={'key': 'value'}
608
+ )
609
+
610
+ assert module.method_called
611
+ assert module.received_args == (1, 2)
612
+ assert module.received_kwargs == {'key': 'value'}
613
+
614
+ def test_nested_modules(self):
615
+ """Test that call_all_fns calls methods on nested modules"""
616
+ module = NestedCallModule()
617
+
618
+ brainstate.nn.call_all_fns(module, 'custom_method')
619
+
620
+ assert module.method_called
621
+ assert module.child1.method_called
622
+ assert module.child2.method_called
623
+
624
+ def test_call_order_respected(self):
625
+ """Test that call_order is respected"""
626
+ execution_tracker = []
627
+
628
+ class ParentModule(brainstate.nn.Module):
629
+ def __init__(self):
630
+ super().__init__()
631
+ self.child1 = OrderedCallModule(execution_tracker)
632
+ self.child2 = OrderedCallModule(execution_tracker)
633
+
634
+ module = ParentModule()
635
+
636
+ # Call ordered_method_1 on all modules (parent doesn't have it, so skip)
637
+ brainstate.nn.call_all_fns(module, 'ordered_method_1', fn_if_not_exist='pass')
638
+
639
+ # Should be called on both children (both have order 1)
640
+ assert execution_tracker.count('order_1') == 2
641
+
642
+ def test_execution_order_with_mixed_decorators(self):
643
+ """Test execution order with both ordered and unordered methods"""
644
+ execution_tracker = []
645
+
646
+ class MixedModule(brainstate.nn.Module):
647
+ def __init__(self):
648
+ super().__init__()
649
+ self.ordered = OrderedCallModule(execution_tracker)
650
+
651
+ def unordered_method(self):
652
+ execution_tracker.append('parent_unordered')
653
+
654
+ module = MixedModule()
655
+
656
+ # Call unordered_method - parent has no decorator, child has no decorator
657
+ execution_tracker.clear()
658
+ brainstate.nn.call_all_fns(module, 'unordered_method')
659
+
660
+ # Both should be called (unordered methods execute first)
661
+ assert 'parent_unordered' in execution_tracker
662
+ assert 'unordered' in execution_tracker
663
+
664
+ def test_node_to_exclude(self):
665
+ """Test excluding specific nodes"""
666
+ module = NestedCallModule()
667
+
668
+ # Exclude the parent module
669
+ brainstate.nn.call_all_fns(
670
+ module,
671
+ 'custom_method',
672
+ node_to_exclude=NestedCallModule
673
+ )
674
+
675
+ # Parent should not be called, but children should be
676
+ assert not module.method_called
677
+ assert module.child1.method_called
678
+ assert module.child2.method_called
679
+
680
+ def test_fn_if_not_exist_raise(self):
681
+ """Test fn_if_not_exist='raise' behavior"""
682
+ module = CustomMethodModule()
683
+
684
+ with pytest.raises(AttributeError, match="does not have method"):
685
+ brainstate.nn.call_all_fns(module, 'nonexistent_method', fn_if_not_exist='raise')
686
+
687
+ def test_fn_if_not_exist_pass(self):
688
+ """Test fn_if_not_exist='pass' behavior"""
689
+ module = CustomMethodModule()
690
+
691
+ # Should not raise error
692
+ brainstate.nn.call_all_fns(module, 'nonexistent_method', fn_if_not_exist='pass')
693
+
694
+ def test_fn_if_not_exist_none(self):
695
+ """Test fn_if_not_exist='none' behavior"""
696
+ module = CustomMethodModule()
697
+
698
+ # Should not raise error
699
+ brainstate.nn.call_all_fns(module, 'nonexistent_method', fn_if_not_exist='none')
700
+
701
+ def test_fn_if_not_exist_warn(self):
702
+ """Test fn_if_not_exist='warn' behavior"""
703
+ module = CustomMethodModule()
704
+
705
+ # Should issue warning but not raise
706
+ with pytest.warns(UserWarning, match="does not have method"):
707
+ brainstate.nn.call_all_fns(module, 'nonexistent_method', fn_if_not_exist='warn')
708
+
709
+ def test_invalid_fn_name_type(self):
710
+ """Test that invalid fn_name type raises error"""
711
+ module = CustomMethodModule()
712
+
713
+ with pytest.raises(TypeError, match="fn_name must be a string"):
714
+ brainstate.nn.call_all_fns(module, 123)
715
+
716
+ def test_invalid_kwargs_type(self):
717
+ """Test that invalid kwargs type raises error"""
718
+ module = CustomMethodModule()
719
+
720
+ with pytest.raises(TypeError, match="kwargs must be a mapping"):
721
+ brainstate.nn.call_all_fns(module, 'custom_method', kwargs=[1, 2, 3])
722
+
723
+ def test_non_callable_attribute(self):
724
+ """Test that non-callable attributes raise error"""
725
+
726
+ class ModuleWithAttribute(brainstate.nn.Module):
727
+ def __init__(self):
728
+ super().__init__()
729
+ self.my_attr = "not callable"
730
+
731
+ module = ModuleWithAttribute()
732
+
733
+ with pytest.raises(TypeError, match="must be callable"):
734
+ brainstate.nn.call_all_fns(module, 'my_attr')
735
+
736
+ def test_with_gru_cell(self):
737
+ """Test with real GRU cell"""
738
+ gru = brainstate.nn.GRUCell(5, 10)
739
+
740
+ # Initialize states
741
+ brainstate.nn.call_all_fns(gru, 'init_state', kwargs={'batch_size': 8})
742
+
743
+ # Verify states were created
744
+ states = gru.states()
745
+ assert len(states) > 0
746
+
747
+ def test_multiple_calls_same_function(self):
748
+ """Test calling same function multiple times"""
749
+ module = CustomMethodModule()
750
+
751
+ for i in range(5):
752
+ brainstate.nn.call_all_fns(module, 'custom_method')
753
+
754
+ assert module.call_count == 5
755
+
756
+ def test_single_non_tuple_arg(self):
757
+ """Test that single non-tuple argument is wrapped"""
758
+ module = CustomMethodModule()
759
+
760
+ brainstate.nn.call_all_fns(module, 'custom_method', args=42)
761
+
762
+ assert module.received_args == (42,)
763
+
764
+ def test_sequential_module(self):
765
+ """Test with Sequential module"""
766
+ layer1 = CustomMethodModule()
767
+ layer2 = CustomMethodModule()
768
+ seq = brainstate.nn.Sequential(layer1, layer2)
769
+
770
+ # Sequential itself doesn't have custom_method, so skip it
771
+ brainstate.nn.call_all_fns(seq, 'custom_method', fn_if_not_exist='pass')
772
+
773
+ assert layer1.method_called
774
+ assert layer2.method_called