brainstate 0.2.1__py2.py3-none-any.whl → 0.2.2__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (115) hide show
  1. brainstate/__init__.py +167 -169
  2. brainstate/_compatible_import.py +340 -340
  3. brainstate/_compatible_import_test.py +681 -681
  4. brainstate/_deprecation.py +210 -210
  5. brainstate/_deprecation_test.py +2297 -2319
  6. brainstate/_error.py +45 -45
  7. brainstate/_state.py +2157 -1652
  8. brainstate/_state_test.py +1129 -52
  9. brainstate/_utils.py +47 -47
  10. brainstate/environ.py +1495 -1495
  11. brainstate/environ_test.py +1223 -1223
  12. brainstate/graph/__init__.py +22 -22
  13. brainstate/graph/_node.py +240 -240
  14. brainstate/graph/_node_test.py +589 -589
  15. brainstate/graph/_operation.py +1620 -1624
  16. brainstate/graph/_operation_test.py +1147 -1147
  17. brainstate/mixin.py +1447 -1433
  18. brainstate/mixin_test.py +1017 -1017
  19. brainstate/nn/__init__.py +146 -137
  20. brainstate/nn/_activations.py +1100 -1100
  21. brainstate/nn/_activations_test.py +354 -354
  22. brainstate/nn/_collective_ops.py +635 -633
  23. brainstate/nn/_collective_ops_test.py +774 -774
  24. brainstate/nn/_common.py +226 -226
  25. brainstate/nn/_common_test.py +134 -154
  26. brainstate/nn/_conv.py +2010 -2010
  27. brainstate/nn/_conv_test.py +849 -849
  28. brainstate/nn/_delay.py +575 -575
  29. brainstate/nn/_delay_test.py +243 -243
  30. brainstate/nn/_dropout.py +618 -618
  31. brainstate/nn/_dropout_test.py +480 -477
  32. brainstate/nn/_dynamics.py +870 -1267
  33. brainstate/nn/_dynamics_test.py +53 -67
  34. brainstate/nn/_elementwise.py +1298 -1298
  35. brainstate/nn/_elementwise_test.py +829 -829
  36. brainstate/nn/_embedding.py +408 -408
  37. brainstate/nn/_embedding_test.py +156 -156
  38. brainstate/nn/_event_fixedprob.py +233 -233
  39. brainstate/nn/_event_fixedprob_test.py +115 -115
  40. brainstate/nn/_event_linear.py +83 -83
  41. brainstate/nn/_event_linear_test.py +121 -121
  42. brainstate/nn/_exp_euler.py +254 -254
  43. brainstate/nn/_exp_euler_test.py +377 -377
  44. brainstate/nn/_linear.py +744 -744
  45. brainstate/nn/_linear_test.py +475 -475
  46. brainstate/nn/_metrics.py +1070 -1070
  47. brainstate/nn/_metrics_test.py +611 -611
  48. brainstate/nn/_module.py +391 -384
  49. brainstate/nn/_module_test.py +40 -40
  50. brainstate/nn/_normalizations.py +1334 -1334
  51. brainstate/nn/_normalizations_test.py +699 -699
  52. brainstate/nn/_paddings.py +1020 -1020
  53. brainstate/nn/_paddings_test.py +722 -722
  54. brainstate/nn/_poolings.py +2239 -2239
  55. brainstate/nn/_poolings_test.py +952 -952
  56. brainstate/nn/_rnns.py +946 -946
  57. brainstate/nn/_rnns_test.py +592 -592
  58. brainstate/nn/_utils.py +216 -216
  59. brainstate/nn/_utils_test.py +401 -401
  60. brainstate/nn/init.py +809 -809
  61. brainstate/nn/init_test.py +180 -180
  62. brainstate/random/__init__.py +270 -270
  63. brainstate/random/{_rand_funs.py → _fun.py} +3938 -3938
  64. brainstate/random/{_rand_funs_test.py → _fun_test.py} +638 -640
  65. brainstate/random/_impl.py +672 -0
  66. brainstate/random/{_rand_seed.py → _seed.py} +675 -675
  67. brainstate/random/{_rand_seed_test.py → _seed_test.py} +48 -48
  68. brainstate/random/{_rand_state.py → _state.py} +1320 -1617
  69. brainstate/random/{_rand_state_test.py → _state_test.py} +551 -551
  70. brainstate/transform/__init__.py +56 -59
  71. brainstate/transform/_ad_checkpoint.py +176 -176
  72. brainstate/transform/_ad_checkpoint_test.py +49 -49
  73. brainstate/transform/_autograd.py +1025 -1025
  74. brainstate/transform/_autograd_test.py +1289 -1289
  75. brainstate/transform/_conditions.py +316 -316
  76. brainstate/transform/_conditions_test.py +220 -220
  77. brainstate/transform/_error_if.py +94 -94
  78. brainstate/transform/_error_if_test.py +52 -52
  79. brainstate/transform/_find_state.py +200 -0
  80. brainstate/transform/_find_state_test.py +84 -0
  81. brainstate/transform/_jit.py +399 -399
  82. brainstate/transform/_jit_test.py +143 -143
  83. brainstate/transform/_loop_collect_return.py +675 -675
  84. brainstate/transform/_loop_collect_return_test.py +58 -58
  85. brainstate/transform/_loop_no_collection.py +283 -283
  86. brainstate/transform/_loop_no_collection_test.py +50 -50
  87. brainstate/transform/_make_jaxpr.py +2176 -2016
  88. brainstate/transform/_make_jaxpr_test.py +1634 -1510
  89. brainstate/transform/_mapping.py +607 -529
  90. brainstate/transform/_mapping_test.py +104 -194
  91. brainstate/transform/_progress_bar.py +255 -255
  92. brainstate/transform/_unvmap.py +256 -256
  93. brainstate/transform/_util.py +286 -286
  94. brainstate/typing.py +837 -837
  95. brainstate/typing_test.py +780 -780
  96. brainstate/util/__init__.py +27 -27
  97. brainstate/util/_others.py +1024 -1024
  98. brainstate/util/_others_test.py +962 -962
  99. brainstate/util/_pretty_pytree.py +1301 -1301
  100. brainstate/util/_pretty_pytree_test.py +675 -675
  101. brainstate/util/_pretty_repr.py +462 -462
  102. brainstate/util/_pretty_repr_test.py +696 -696
  103. brainstate/util/filter.py +945 -945
  104. brainstate/util/filter_test.py +911 -911
  105. brainstate/util/struct.py +910 -910
  106. brainstate/util/struct_test.py +602 -602
  107. {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/METADATA +108 -108
  108. brainstate-0.2.2.dist-info/RECORD +111 -0
  109. {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/licenses/LICENSE +202 -202
  110. brainstate/transform/_eval_shape.py +0 -145
  111. brainstate/transform/_eval_shape_test.py +0 -38
  112. brainstate/transform/_random.py +0 -171
  113. brainstate-0.2.1.dist-info/RECORD +0 -111
  114. {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/WHEEL +0 -0
  115. {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/top_level.txt +0 -0
@@ -1,593 +1,593 @@
1
- # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
- """Comprehensive tests for RNN cell implementations."""
17
-
18
- import unittest
19
- from typing import Type
20
-
21
- import jax
22
- import jax.numpy as jnp
23
- import numpy as np
24
-
25
- import brainstate
26
- import brainstate.nn as nn
27
- from brainstate.nn import RNNCell, ValinaRNNCell, GRUCell, MGUCell, LSTMCell, URLSTMCell
28
- from brainstate.nn import init as init
29
- from brainstate.nn import _activations as functional
30
-
31
-
32
- class TestRNNCellBase(unittest.TestCase):
33
- """Base test class for all RNN cell implementations."""
34
-
35
- def setUp(self):
36
- """Set up test fixtures."""
37
- self.num_in = 10
38
- self.num_out = 20
39
- self.batch_size = 32
40
- self.sequence_length = 100
41
- self.seed = 42
42
-
43
- # Initialize random inputs
44
- key = jax.random.PRNGKey(self.seed)
45
- self.x = jax.random.normal(key, (self.batch_size, self.num_in))
46
- self.sequence = jax.random.normal(
47
- key, (self.sequence_length, self.batch_size, self.num_in)
48
- )
49
-
50
-
51
- class TestVanillaRNNCell(TestRNNCellBase):
52
- """Comprehensive tests for VanillaRNNCell."""
53
-
54
- def test_basic_forward(self):
55
- """Test basic forward pass."""
56
- cell = ValinaRNNCell(num_in=self.num_in, num_out=self.num_out)
57
- cell.init_state(batch_size=self.batch_size)
58
-
59
- output = cell.update(self.x)
60
-
61
- self.assertEqual(output.shape, (self.batch_size, self.num_out))
62
- self.assertFalse(jnp.any(jnp.isnan(output)))
63
- self.assertFalse(jnp.any(jnp.isinf(output)))
64
-
65
- def test_sequence_processing(self):
66
- """Test processing a sequence of inputs."""
67
- cell = ValinaRNNCell(num_in=self.num_in, num_out=self.num_out)
68
- cell.init_state(batch_size=self.batch_size)
69
-
70
- outputs = []
71
- for t in range(self.sequence_length):
72
- output = cell.update(self.sequence[t])
73
- outputs.append(output)
74
-
75
- outputs = jnp.stack(outputs)
76
- self.assertEqual(outputs.shape, (self.sequence_length, self.batch_size, self.num_out))
77
- self.assertFalse(jnp.any(jnp.isnan(outputs)))
78
-
79
- def test_state_reset(self):
80
- """Test state reset functionality."""
81
- cell = ValinaRNNCell(num_in=self.num_in, num_out=self.num_out)
82
- cell.init_state(batch_size=self.batch_size)
83
-
84
- # Process some input
85
- _ = cell.update(self.x)
86
- state_before = cell.h.value.copy()
87
-
88
- # Reset state
89
- cell.reset_state(batch_size=self.batch_size)
90
- state_after = cell.h.value.copy()
91
-
92
- # States should be different (unless randomly the same, which is unlikely)
93
- self.assertFalse(jnp.allclose(state_before, state_after, atol=1e-6))
94
-
95
- def test_different_batch_sizes(self):
96
- """Test with different batch sizes."""
97
- cell = ValinaRNNCell(num_in=self.num_in, num_out=self.num_out)
98
-
99
- for batch_size in [1, 16, 64]:
100
- cell.init_state(batch_size=batch_size)
101
- x = jnp.ones((batch_size, self.num_in))
102
- output = cell.update(x)
103
- self.assertEqual(output.shape, (batch_size, self.num_out))
104
-
105
- def test_activation_functions(self):
106
- """Test different activation functions."""
107
- activations = ['relu', 'tanh', 'sigmoid', 'gelu']
108
-
109
- for activation in activations:
110
- cell = ValinaRNNCell(
111
- num_in=self.num_in,
112
- num_out=self.num_out,
113
- activation=activation
114
- )
115
- cell.init_state(batch_size=self.batch_size)
116
- output = cell.update(self.x)
117
- self.assertEqual(output.shape, (self.batch_size, self.num_out))
118
- self.assertFalse(jnp.any(jnp.isnan(output)))
119
-
120
- def test_custom_initializers(self):
121
- """Test custom weight and state initializers."""
122
- cell = ValinaRNNCell(
123
- num_in=self.num_in,
124
- num_out=self.num_out,
125
- w_init=init.Orthogonal(),
126
- b_init=init.Constant(0.1),
127
- state_init=init.Normal(0.01)
128
- )
129
- cell.init_state(batch_size=self.batch_size)
130
- output = cell.update(self.x)
131
- self.assertEqual(output.shape, (self.batch_size, self.num_out))
132
-
133
- def test_gradient_flow(self):
134
- """Test gradient flow through the cell."""
135
- cell = ValinaRNNCell(num_in=self.num_in, num_out=self.num_out)
136
- cell.init_state(batch_size=self.batch_size)
137
-
138
- def loss_fn(x):
139
- output = cell.update(x)
140
- return jnp.mean(output ** 2)
141
-
142
- grad_fn = jax.grad(loss_fn)
143
- grad = grad_fn(self.x)
144
-
145
- self.assertEqual(grad.shape, self.x.shape)
146
- self.assertFalse(jnp.any(jnp.isnan(grad)))
147
- self.assertTrue(jnp.any(grad != 0)) # Gradients should be non-zero
148
-
149
-
150
- class TestGRUCell(TestRNNCellBase):
151
- """Comprehensive tests for GRUCell."""
152
-
153
- def test_basic_forward(self):
154
- """Test basic forward pass."""
155
- cell = GRUCell(num_in=self.num_in, num_out=self.num_out)
156
- cell.init_state(batch_size=self.batch_size)
157
-
158
- output = cell.update(self.x)
159
-
160
- self.assertEqual(output.shape, (self.batch_size, self.num_out))
161
- self.assertFalse(jnp.any(jnp.isnan(output)))
162
-
163
- def test_gating_mechanism(self):
164
- """Test that gating values are in valid range."""
165
- cell = GRUCell(num_in=self.num_in, num_out=self.num_out)
166
- cell.init_state(batch_size=self.batch_size)
167
-
168
- # Access internal computation
169
- old_h = cell.h.value
170
- xh = jnp.concatenate([self.x, old_h], axis=-1)
171
- gates = functional.sigmoid(cell.Wrz(xh))
172
-
173
- # Gates should be between 0 and 1
174
- self.assertTrue(jnp.all(gates >= 0))
175
- self.assertTrue(jnp.all(gates <= 1))
176
-
177
- def test_state_persistence(self):
178
- """Test that state persists across updates."""
179
- cell = GRUCell(num_in=self.num_in, num_out=self.num_out)
180
- cell.init_state(batch_size=self.batch_size)
181
-
182
- # Process sequence and track states
183
- states = []
184
- for t in range(10):
185
- _ = cell.update(self.sequence[t])
186
- states.append(cell.h.value.copy())
187
-
188
- # States should evolve over time
189
- for i in range(1, len(states)):
190
- self.assertFalse(jnp.allclose(states[i], states[i-1], atol=1e-8))
191
-
192
- def test_reset_vs_update_gates(self):
193
- """Test that reset and update gates behave differently."""
194
- cell = GRUCell(num_in=self.num_in, num_out=self.num_out)
195
- cell.init_state(batch_size=self.batch_size)
196
-
197
- # Get gates for the same input
198
- old_h = cell.h.value
199
- xh = jnp.concatenate([self.x, old_h], axis=-1)
200
- r, z = jnp.split(functional.sigmoid(cell.Wrz(xh)), indices_or_sections=2, axis=-1)
201
-
202
- # Reset and update gates should be different
203
- self.assertFalse(jnp.allclose(r, z, atol=1e-6))
204
-
205
- def test_different_initializers(self):
206
- """Test with different weight initializers."""
207
- initializers = [
208
- init.XavierNormal(),
209
- init.XavierUniform(),
210
- init.Orthogonal(),
211
- init.KaimingNormal(),
212
- ]
213
-
214
- for w_init in initializers:
215
- cell = GRUCell(
216
- num_in=self.num_in,
217
- num_out=self.num_out,
218
- w_init=w_init
219
- )
220
- cell.init_state(batch_size=self.batch_size)
221
- output = cell.update(self.x)
222
- self.assertEqual(output.shape, (self.batch_size, self.num_out))
223
-
224
-
225
- class TestMGUCell(TestRNNCellBase):
226
- """Comprehensive tests for MGUCell."""
227
-
228
- def test_basic_forward(self):
229
- """Test basic forward pass."""
230
- cell = MGUCell(num_in=self.num_in, num_out=self.num_out)
231
- cell.init_state(batch_size=self.batch_size)
232
-
233
- output = cell.update(self.x)
234
-
235
- self.assertEqual(output.shape, (self.batch_size, self.num_out))
236
- self.assertFalse(jnp.any(jnp.isnan(output)))
237
-
238
- def test_single_gate_mechanism(self):
239
- """Test that MGU uses single forget gate."""
240
- cell = MGUCell(num_in=self.num_in, num_out=self.num_out)
241
- cell.init_state(batch_size=self.batch_size)
242
-
243
- # Check that only one gate is computed
244
- xh = jnp.concatenate([self.x, cell.h.value], axis=-1)
245
- f = functional.sigmoid(cell.Wf(xh))
246
-
247
- # Forget gate should be between 0 and 1
248
- self.assertTrue(jnp.all(f >= 0))
249
- self.assertTrue(jnp.all(f <= 1))
250
- self.assertEqual(f.shape, (self.batch_size, self.num_out))
251
-
252
- def test_parameter_efficiency(self):
253
- """Test that MGU has fewer parameters than GRU."""
254
- mgu_cell = MGUCell(num_in=self.num_in, num_out=self.num_out)
255
- gru_cell = GRUCell(num_in=self.num_in, num_out=self.num_out)
256
-
257
- # Count parameters - MGU should have fewer
258
- # MGU has 2 weight matrices (Wf, Wh)
259
- # GRU has 2 weight matrices but one is double size (Wrz, Wh)
260
- mgu_param_count = 2 * ((self.num_in + self.num_out) * self.num_out + self.num_out)
261
- gru_param_count = ((self.num_in + self.num_out) * (self.num_out * 2) + self.num_out * 2) + \
262
- ((self.num_in + self.num_out) * self.num_out + self.num_out)
263
-
264
- self.assertLess(mgu_param_count, gru_param_count)
265
-
266
-
267
- class TestLSTMCell(TestRNNCellBase):
268
- """Comprehensive tests for LSTMCell."""
269
-
270
- def test_basic_forward(self):
271
- """Test basic forward pass."""
272
- cell = LSTMCell(num_in=self.num_in, num_out=self.num_out)
273
- cell.init_state(batch_size=self.batch_size)
274
-
275
- output = cell.update(self.x)
276
-
277
- self.assertEqual(output.shape, (self.batch_size, self.num_out))
278
- self.assertFalse(jnp.any(jnp.isnan(output)))
279
-
280
- def test_dual_state_mechanism(self):
281
- """Test that LSTM maintains both hidden and cell states."""
282
- cell = LSTMCell(num_in=self.num_in, num_out=self.num_out)
283
- cell.init_state(batch_size=self.batch_size)
284
-
285
- # Check initial states
286
- self.assertIsNotNone(cell.h)
287
- self.assertIsNotNone(cell.c)
288
- self.assertEqual(cell.h.value.shape, (self.batch_size, self.num_out))
289
- self.assertEqual(cell.c.value.shape, (self.batch_size, self.num_out))
290
-
291
- # Update and check states change
292
- h_before = cell.h.value.copy()
293
- c_before = cell.c.value.copy()
294
-
295
- _ = cell.update(self.x)
296
-
297
- self.assertFalse(jnp.allclose(cell.h.value, h_before, atol=1e-8))
298
- self.assertFalse(jnp.allclose(cell.c.value, c_before, atol=1e-8))
299
-
300
- def test_forget_gate_bias(self):
301
- """Test that forget gate has positive bias initialization."""
302
- cell = LSTMCell(num_in=self.num_in, num_out=self.num_out)
303
- cell.init_state(batch_size=self.batch_size)
304
-
305
- # Process with zero input to see bias effect
306
- zero_input = jnp.zeros((self.batch_size, self.num_in))
307
- xh = jnp.concatenate([zero_input, cell.h.value], axis=-1)
308
- gates = cell.W(xh)
309
- _, _, f, _ = jnp.split(gates, indices_or_sections=4, axis=-1)
310
- f_gate = functional.sigmoid(f + 1.) # Note the +1 bias
311
-
312
- # Forget gate should be biased towards remembering (> 0.5)
313
- self.assertTrue(jnp.mean(f_gate) > 0.5)
314
-
315
- def test_gate_values_range(self):
316
- """Test that all gates produce values in [0, 1]."""
317
- cell = LSTMCell(num_in=self.num_in, num_out=self.num_out)
318
- cell.init_state(batch_size=self.batch_size)
319
-
320
- xh = jnp.concatenate([self.x, cell.h.value], axis=-1)
321
- i, g, f, o = jnp.split(cell.W(xh), indices_or_sections=4, axis=-1)
322
-
323
- i_gate = functional.sigmoid(i)
324
- f_gate = functional.sigmoid(f + 1.)
325
- o_gate = functional.sigmoid(o)
326
-
327
- for gate in [i_gate, f_gate, o_gate]:
328
- self.assertTrue(jnp.all(gate >= 0))
329
- self.assertTrue(jnp.all(gate <= 1))
330
-
331
- def test_cell_state_gradient_flow(self):
332
- """Test gradient flow through cell state."""
333
- cell = LSTMCell(num_in=self.num_in, num_out=self.num_out)
334
- cell.init_state(batch_size=self.batch_size)
335
-
336
- def loss_fn(x):
337
- for t in range(10):
338
- _ = cell.update(x)
339
- return jnp.mean(cell.c.value ** 2)
340
-
341
- grad_fn = jax.grad(loss_fn)
342
- grad = grad_fn(self.x)
343
-
344
- self.assertFalse(jnp.any(jnp.isnan(grad)))
345
- self.assertTrue(jnp.any(grad != 0))
346
-
347
-
348
- class TestURLSTMCell(TestRNNCellBase):
349
- """Comprehensive tests for URLSTMCell."""
350
-
351
- def test_basic_forward(self):
352
- """Test basic forward pass."""
353
- cell = URLSTMCell(num_in=self.num_in, num_out=self.num_out)
354
- cell.init_state(batch_size=self.batch_size)
355
-
356
- output = cell.update(self.x)
357
-
358
- self.assertEqual(output.shape, (self.batch_size, self.num_out))
359
- self.assertFalse(jnp.any(jnp.isnan(output)))
360
-
361
- def test_untied_bias_mechanism(self):
362
- """Test the untied bias initialization."""
363
- cell = URLSTMCell(num_in=self.num_in, num_out=self.num_out)
364
- cell.init_state(batch_size=self.batch_size)
365
-
366
- # Check bias values are initialized
367
- self.assertIsNotNone(cell.bias.value)
368
- self.assertEqual(cell.bias.value.shape, (self.num_out,))
369
-
370
- # Biases should be diverse (not all the same)
371
- self.assertGreater(jnp.std(cell.bias.value), 0.1)
372
-
373
- def test_unified_gate_computation(self):
374
- """Test the unified gate mechanism."""
375
- cell = URLSTMCell(num_in=self.num_in, num_out=self.num_out)
376
- cell.init_state(batch_size=self.batch_size)
377
-
378
- h, c = cell.h.value, cell.c.value
379
- xh = jnp.concatenate([self.x, h], axis=-1)
380
- gates = cell.W(xh)
381
- f, r, u, o = jnp.split(gates, indices_or_sections=4, axis=-1)
382
-
383
- f_gate = functional.sigmoid(f + cell.bias.value)
384
- r_gate = functional.sigmoid(r - cell.bias.value)
385
-
386
- # Compute unified gate
387
- g = 2 * r_gate * f_gate + (1 - 2 * r_gate) * f_gate ** 2
388
-
389
- # Unified gate should be in [0, 1]
390
- self.assertTrue(jnp.all(g >= 0))
391
- self.assertTrue(jnp.all(g <= 1))
392
-
393
- def test_comparison_with_lstm(self):
394
- """Test that URLSTM behaves differently from standard LSTM."""
395
- urlstm = URLSTMCell(num_in=self.num_in, num_out=self.num_out, state_init=init.Constant(0.5))
396
- lstm = LSTMCell(num_in=self.num_in, num_out=self.num_out, state_init=init.Constant(0.5))
397
-
398
- urlstm.init_state(batch_size=self.batch_size)
399
- lstm.init_state(batch_size=self.batch_size)
400
-
401
- # Same input should produce different outputs
402
- urlstm_out = urlstm.update(self.x)
403
- lstm_out = lstm.update(self.x)
404
-
405
- self.assertFalse(jnp.allclose(urlstm_out, lstm_out, atol=1e-4))
406
-
407
-
408
- class TestRNNCellIntegration(TestRNNCellBase):
409
- """Integration tests for all RNN cells."""
410
-
411
- def test_all_cells_compatible_interface(self):
412
- """Test that all cells have compatible interfaces."""
413
- cell_types = [ValinaRNNCell, GRUCell, MGUCell, LSTMCell, URLSTMCell]
414
-
415
- for CellType in cell_types:
416
- cell = CellType(num_in=self.num_in, num_out=self.num_out)
417
-
418
- # Test init_state
419
- cell.init_state(batch_size=self.batch_size)
420
-
421
- # Test update
422
- output = cell.update(self.x)
423
- self.assertEqual(output.shape, (self.batch_size, self.num_out))
424
-
425
- # Test reset_state
426
- cell.reset_state(batch_size=16)
427
-
428
- # Test with new batch size
429
- x_small = jnp.ones((16, self.num_in))
430
- output_small = cell.update(x_small)
431
- self.assertEqual(output_small.shape, (16, self.num_out))
432
-
433
- def test_sequence_to_sequence(self):
434
- """Test sequence-to-sequence processing."""
435
- cell_types = [ValinaRNNCell, GRUCell, MGUCell, LSTMCell, URLSTMCell]
436
-
437
- for CellType in cell_types:
438
- cell = CellType(num_in=self.num_in, num_out=self.num_out)
439
- cell.init_state(batch_size=self.batch_size)
440
-
441
- outputs = []
442
- for t in range(self.sequence_length):
443
- output = cell.update(self.sequence[t])
444
- outputs.append(output)
445
-
446
- outputs = jnp.stack(outputs)
447
- self.assertEqual(
448
- outputs.shape,
449
- (self.sequence_length, self.batch_size, self.num_out)
450
- )
451
-
452
- def test_variable_length_sequences(self):
453
- """Test handling of variable length sequences with masking."""
454
- cell = LSTMCell(num_in=self.num_in, num_out=self.num_out)
455
- cell.init_state(batch_size=self.batch_size)
456
-
457
- # Create mask for variable lengths
458
- lengths = jnp.array([10, 20, 30, 40] * (self.batch_size // 4))
459
- mask = jnp.arange(self.sequence_length)[:, None] < lengths[None, :]
460
-
461
- outputs = []
462
- for t in range(self.sequence_length):
463
- output = cell.update(self.sequence[t])
464
- # Apply mask
465
- output = output * mask[t:t+1].T
466
- outputs.append(output)
467
-
468
- outputs = jnp.stack(outputs)
469
-
470
- # Check that masked positions are zero
471
- for b in range(self.batch_size):
472
- length = lengths[b]
473
- if length < self.sequence_length:
474
- self.assertTrue(jnp.allclose(outputs[length:, b, :], 0.0))
475
-
476
- def test_gradient_clipping(self):
477
- """Test gradient clipping during training."""
478
- cell = GRUCell(num_in=self.num_in, num_out=self.num_out)
479
- cell.init_state(batch_size=self.batch_size)
480
-
481
- def loss_fn(x):
482
- output = jnp.zeros((self.batch_size, self.num_out))
483
- for t in range(50): # Long sequence
484
- output = cell.update(x * (t + 1)) # Amplify input
485
- return jnp.mean(output ** 2)
486
-
487
- grad_fn = jax.grad(loss_fn)
488
- grad = grad_fn(self.x)
489
-
490
- # Gradients should not explode
491
- self.assertFalse(jnp.any(jnp.isnan(grad)))
492
- self.assertFalse(jnp.any(jnp.isinf(grad)))
493
- self.assertLess(jnp.max(jnp.abs(grad)), 1e6)
494
-
495
-
496
- class TestRNNCellEdgeCases(TestRNNCellBase):
497
- """Edge case tests for RNN cells."""
498
-
499
- def test_single_sample(self):
500
- """Test with batch size of 1."""
501
- cell = ValinaRNNCell(num_in=self.num_in, num_out=self.num_out)
502
- cell.init_state(batch_size=1)
503
-
504
- x = jnp.ones((1, self.num_in))
505
- output = cell.update(x)
506
- self.assertEqual(output.shape, (1, self.num_out))
507
-
508
- def test_zero_input(self):
509
- """Test with zero inputs."""
510
- cell_types = [ValinaRNNCell, GRUCell, MGUCell, LSTMCell, URLSTMCell]
511
-
512
- for CellType in cell_types:
513
- cell = CellType(num_in=self.num_in, num_out=self.num_out)
514
- cell.init_state(batch_size=self.batch_size)
515
-
516
- zero_input = jnp.zeros((self.batch_size, self.num_in))
517
- output = cell.update(zero_input)
518
-
519
- self.assertEqual(output.shape, (self.batch_size, self.num_out))
520
- self.assertFalse(jnp.any(jnp.isnan(output)))
521
-
522
- def test_large_input_values(self):
523
- """Test with large input values."""
524
- cell = LSTMCell(num_in=self.num_in, num_out=self.num_out)
525
- cell.init_state(batch_size=self.batch_size)
526
-
527
- large_input = jnp.ones((self.batch_size, self.num_in)) * 100
528
- output = cell.update(large_input)
529
-
530
- # Should handle large inputs gracefully (sigmoid saturation)
531
- self.assertFalse(jnp.any(jnp.isnan(output)))
532
- self.assertFalse(jnp.any(jnp.isinf(output)))
533
-
534
- def test_very_long_sequence(self):
535
- """Test with very long sequences."""
536
- cell = GRUCell(num_in=self.num_in, num_out=self.num_out)
537
- cell.init_state(batch_size=4) # Smaller batch for memory
538
-
539
- long_sequence = jnp.ones((1000, 4, self.num_in))
540
-
541
- final_output = None
542
- for t in range(1000):
543
- final_output = cell.update(long_sequence[t])
544
-
545
- # Should not have numerical issues even after long sequence
546
- self.assertFalse(jnp.any(jnp.isnan(final_output)))
547
- self.assertFalse(jnp.any(jnp.isinf(final_output)))
548
-
549
- def test_dimension_mismatch_error(self):
550
- """Test that dimension mismatches raise appropriate errors."""
551
- cell = ValinaRNNCell(num_in=self.num_in, num_out=self.num_out)
552
- cell.init_state(batch_size=self.batch_size)
553
-
554
- # Wrong input dimension should raise error
555
- wrong_input = jnp.ones((self.batch_size, self.num_in + 5))
556
-
557
- with self.assertRaises(Exception):
558
- _ = cell.update(wrong_input)
559
-
560
-
561
- class TestRNNCellProperties(TestRNNCellBase):
562
- """Test cell properties and attributes."""
563
-
564
- def test_cell_attributes(self):
565
- """Test that cells have correct attributes."""
566
- cell = LSTMCell(num_in=self.num_in, num_out=self.num_out)
567
-
568
- self.assertEqual(cell.num_in, self.num_in)
569
- self.assertEqual(cell.num_out, self.num_out)
570
- self.assertEqual(cell.in_size, (self.num_in,))
571
- self.assertEqual(cell.out_size, (self.num_out,))
572
-
573
- def test_inheritance_structure(self):
574
- """Test that all cells inherit from RNNCell."""
575
- cell_types = [ValinaRNNCell, GRUCell, MGUCell, LSTMCell, URLSTMCell]
576
-
577
- for CellType in cell_types:
578
- cell = CellType(num_in=self.num_in, num_out=self.num_out)
579
- self.assertIsInstance(cell, RNNCell)
580
-
581
- def test_docstring_presence(self):
582
- """Test that all cells have proper docstrings."""
583
- cell_types = [ValinaRNNCell, GRUCell, MGUCell, LSTMCell, URLSTMCell]
584
-
585
- for CellType in cell_types:
586
- self.assertIsNotNone(CellType.__doc__)
587
- self.assertIn("Examples", CellType.__doc__)
588
- self.assertIn("Parameters", CellType.__doc__)
589
- self.assertIn(">>>", CellType.__doc__)
590
-
591
-
592
- if __name__ == '__main__':
1
+ # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """Comprehensive tests for RNN cell implementations."""
17
+
18
+ import unittest
19
+ from typing import Type
20
+
21
+ import jax
22
+ import jax.numpy as jnp
23
+ import numpy as np
24
+
25
+ import brainstate
26
+ import brainstate.nn as nn
27
+ from brainstate.nn import RNNCell, ValinaRNNCell, GRUCell, MGUCell, LSTMCell, URLSTMCell
28
+ from brainstate.nn import init as init
29
+ from brainstate.nn import _activations as functional
30
+
31
+
32
+ class TestRNNCellBase(unittest.TestCase):
33
+ """Base test class for all RNN cell implementations."""
34
+
35
+ def setUp(self):
36
+ """Set up test fixtures."""
37
+ self.num_in = 10
38
+ self.num_out = 20
39
+ self.batch_size = 32
40
+ self.sequence_length = 100
41
+ self.seed = 42
42
+
43
+ # Initialize random inputs
44
+ key = jax.random.PRNGKey(self.seed)
45
+ self.x = jax.random.normal(key, (self.batch_size, self.num_in))
46
+ self.sequence = jax.random.normal(
47
+ key, (self.sequence_length, self.batch_size, self.num_in)
48
+ )
49
+
50
+
51
+ class TestVanillaRNNCell(TestRNNCellBase):
52
+ """Comprehensive tests for VanillaRNNCell."""
53
+
54
+ def test_basic_forward(self):
55
+ """Test basic forward pass."""
56
+ cell = ValinaRNNCell(num_in=self.num_in, num_out=self.num_out)
57
+ cell.init_state(batch_size=self.batch_size)
58
+
59
+ output = cell.update(self.x)
60
+
61
+ self.assertEqual(output.shape, (self.batch_size, self.num_out))
62
+ self.assertFalse(jnp.any(jnp.isnan(output)))
63
+ self.assertFalse(jnp.any(jnp.isinf(output)))
64
+
65
+ def test_sequence_processing(self):
66
+ """Test processing a sequence of inputs."""
67
+ cell = ValinaRNNCell(num_in=self.num_in, num_out=self.num_out)
68
+ cell.init_state(batch_size=self.batch_size)
69
+
70
+ outputs = []
71
+ for t in range(self.sequence_length):
72
+ output = cell.update(self.sequence[t])
73
+ outputs.append(output)
74
+
75
+ outputs = jnp.stack(outputs)
76
+ self.assertEqual(outputs.shape, (self.sequence_length, self.batch_size, self.num_out))
77
+ self.assertFalse(jnp.any(jnp.isnan(outputs)))
78
+
79
+ def test_state_reset(self):
80
+ """Test state reset functionality."""
81
+ cell = ValinaRNNCell(num_in=self.num_in, num_out=self.num_out)
82
+ cell.init_state(batch_size=self.batch_size)
83
+
84
+ # Process some input
85
+ _ = cell.update(self.x)
86
+ state_before = cell.h.value.copy()
87
+
88
+ # Reset state
89
+ cell.reset_state(batch_size=self.batch_size)
90
+ state_after = cell.h.value.copy()
91
+
92
+ # States should be different (unless randomly the same, which is unlikely)
93
+ self.assertFalse(jnp.allclose(state_before, state_after, atol=1e-6))
94
+
95
+ def test_different_batch_sizes(self):
96
+ """Test with different batch sizes."""
97
+ cell = ValinaRNNCell(num_in=self.num_in, num_out=self.num_out)
98
+
99
+ for batch_size in [1, 16, 64]:
100
+ cell.init_state(batch_size=batch_size)
101
+ x = jnp.ones((batch_size, self.num_in))
102
+ output = cell.update(x)
103
+ self.assertEqual(output.shape, (batch_size, self.num_out))
104
+
105
def test_activation_functions(self):
    """Test different activation functions.

    Each activation runs in its own ``subTest``. A failure therefore names the
    activation that broke, and the remaining activations are still checked.
    """
    activations = ['relu', 'tanh', 'sigmoid', 'gelu']

    for activation in activations:
        with self.subTest(activation=activation):
            cell = ValinaRNNCell(
                num_in=self.num_in,
                num_out=self.num_out,
                activation=activation,
            )
            cell.init_state(batch_size=self.batch_size)
            output = cell.update(self.x)
            self.assertEqual(output.shape, (self.batch_size, self.num_out))
            self.assertFalse(jnp.any(jnp.isnan(output)))
119
+
120
def test_custom_initializers(self):
    """Build a cell with explicit weight, bias and state initialisers and run one step."""
    rnn = ValinaRNNCell(
        num_in=self.num_in,
        num_out=self.num_out,
        w_init=init.Orthogonal(),
        b_init=init.Constant(0.1),
        state_init=init.Normal(0.01),
    )
    rnn.init_state(batch_size=self.batch_size)

    result = rnn.update(self.x)
    self.assertEqual(result.shape, (self.batch_size, self.num_out))
132
+
133
def test_gradient_flow(self):
    """Differentiate a squared-output loss w.r.t. the input and check the gradient."""
    rnn = ValinaRNNCell(num_in=self.num_in, num_out=self.num_out)
    rnn.init_state(batch_size=self.batch_size)

    def objective(inputs):
        # NOTE(review): update() runs inside jax.grad. If it writes rnn.h, a traced
        # value may remain stored in the cell afterwards. Confirm this is harmless.
        return jnp.mean(rnn.update(inputs) ** 2)

    input_grad = jax.grad(objective)(self.x)

    self.assertEqual(input_grad.shape, self.x.shape)
    self.assertFalse(jnp.any(jnp.isnan(input_grad)))
    # At least one gradient entry must be non-zero, i.e. the input reaches the loss.
    self.assertTrue(jnp.any(input_grad != 0))
148
+
149
+
150
class TestGRUCell(TestRNNCellBase):
    """Comprehensive tests for GRUCell."""

    def test_basic_forward(self):
        """Test basic forward pass."""
        cell = GRUCell(num_in=self.num_in, num_out=self.num_out)
        cell.init_state(batch_size=self.batch_size)

        output = cell.update(self.x)

        self.assertEqual(output.shape, (self.batch_size, self.num_out))
        self.assertFalse(jnp.any(jnp.isnan(output)))

    def test_gating_mechanism(self):
        """Test that gating values are in valid range."""
        cell = GRUCell(num_in=self.num_in, num_out=self.num_out)
        cell.init_state(batch_size=self.batch_size)

        # Recompute the gates by hand: the cell's `Wrz` layer is applied to the
        # concatenation [input, previous hidden state].
        old_h = cell.h.value
        xh = jnp.concatenate([self.x, old_h], axis=-1)
        gates = functional.sigmoid(cell.Wrz(xh))

        # Gates should be between 0 and 1.
        self.assertTrue(jnp.all(gates >= 0))
        self.assertTrue(jnp.all(gates <= 1))

    def test_state_persistence(self):
        """Test that state persists and evolves across updates."""
        cell = GRUCell(num_in=self.num_in, num_out=self.num_out)
        cell.init_state(batch_size=self.batch_size)

        # Process the first 10 steps of the sequence and snapshot the hidden state.
        states = []
        for t in range(10):
            cell.update(self.sequence[t])
            states.append(cell.h.value.copy())

        # Each hidden state should differ from the one before it.
        for step, (prev, curr) in enumerate(zip(states, states[1:]), start=1):
            with self.subTest(step=step):
                self.assertFalse(jnp.allclose(curr, prev, atol=1e-8))

    def test_reset_vs_update_gates(self):
        """Test that reset and update gates behave differently."""
        cell = GRUCell(num_in=self.num_in, num_out=self.num_out)
        cell.init_state(batch_size=self.batch_size)

        # Compute both gates for the same input; `Wrz` produces them side by side.
        old_h = cell.h.value
        xh = jnp.concatenate([self.x, old_h], axis=-1)
        r, z = jnp.split(functional.sigmoid(cell.Wrz(xh)), indices_or_sections=2, axis=-1)

        # Reset and update gates should be different.
        self.assertFalse(jnp.allclose(r, z, atol=1e-6))

    def test_different_initializers(self):
        """Test with different weight initializers (one subTest per initializer)."""
        initializers = [
            init.XavierNormal(),
            init.XavierUniform(),
            init.Orthogonal(),
            init.KaimingNormal(),
        ]

        for w_init in initializers:
            with self.subTest(w_init=type(w_init).__name__):
                cell = GRUCell(
                    num_in=self.num_in,
                    num_out=self.num_out,
                    w_init=w_init,
                )
                cell.init_state(batch_size=self.batch_size)
                output = cell.update(self.x)
                self.assertEqual(output.shape, (self.batch_size, self.num_out))
223
+
224
+
225
class TestMGUCell(TestRNNCellBase):
    """Comprehensive tests for MGUCell."""

    @staticmethod
    def _count_params(*layers):
        """Return the total number of scalar parameters stored in ``layers``.

        NOTE(review): assumes each layer keeps its parameters as a pytree in
        ``layer.weight.value``, as brainstate's ``Linear`` is expected to.
        Confirm this if the layer API changes.
        """
        return sum(
            int(leaf.size)
            for layer in layers
            for leaf in jax.tree_util.tree_leaves(layer.weight.value)
        )

    def test_basic_forward(self):
        """Test basic forward pass."""
        cell = MGUCell(num_in=self.num_in, num_out=self.num_out)
        cell.init_state(batch_size=self.batch_size)

        output = cell.update(self.x)

        self.assertEqual(output.shape, (self.batch_size, self.num_out))
        self.assertFalse(jnp.any(jnp.isnan(output)))

    def test_single_gate_mechanism(self):
        """Test that MGU uses a single forget gate."""
        cell = MGUCell(num_in=self.num_in, num_out=self.num_out)
        cell.init_state(batch_size=self.batch_size)

        # Recompute the single gate from `Wf` applied to [input, hidden state].
        xh = jnp.concatenate([self.x, cell.h.value], axis=-1)
        f = functional.sigmoid(cell.Wf(xh))

        # The forget gate is in [0, 1] and has one value per output unit.
        self.assertTrue(jnp.all(f >= 0))
        self.assertTrue(jnp.all(f <= 1))
        self.assertEqual(f.shape, (self.batch_size, self.num_out))

    def test_parameter_efficiency(self):
        """Test that MGU has fewer parameters than GRU.

        The counts are read from the instantiated cells instead of from a
        hand-written formula, so the assertion actually exercises both cells.
        """
        mgu_cell = MGUCell(num_in=self.num_in, num_out=self.num_out)
        gru_cell = GRUCell(num_in=self.num_in, num_out=self.num_out)

        # MGU has 2 weight layers (Wf, Wh).
        # GRU has 2 weight layers, one of them double width (Wrz, Wh).
        mgu_param_count = self._count_params(mgu_cell.Wf, mgu_cell.Wh)
        gru_param_count = self._count_params(gru_cell.Wrz, gru_cell.Wh)

        self.assertGreater(mgu_param_count, 0)
        self.assertLess(mgu_param_count, gru_param_count)
265
+
266
+
267
class TestLSTMCell(TestRNNCellBase):
    """Comprehensive tests for LSTMCell."""

    def test_basic_forward(self):
        """Test basic forward pass."""
        cell = LSTMCell(num_in=self.num_in, num_out=self.num_out)
        cell.init_state(batch_size=self.batch_size)

        output = cell.update(self.x)

        self.assertEqual(output.shape, (self.batch_size, self.num_out))
        self.assertFalse(jnp.any(jnp.isnan(output)))

    def test_dual_state_mechanism(self):
        """Test that LSTM maintains both hidden and cell states."""
        cell = LSTMCell(num_in=self.num_in, num_out=self.num_out)
        cell.init_state(batch_size=self.batch_size)

        # Check the initial states exist and have the expected shape.
        self.assertIsNotNone(cell.h)
        self.assertIsNotNone(cell.c)
        self.assertEqual(cell.h.value.shape, (self.batch_size, self.num_out))
        self.assertEqual(cell.c.value.shape, (self.batch_size, self.num_out))

        # Snapshot both states, run one update, and check that both changed.
        h_before = cell.h.value.copy()
        c_before = cell.c.value.copy()

        cell.update(self.x)

        self.assertFalse(jnp.allclose(cell.h.value, h_before, atol=1e-8))
        self.assertFalse(jnp.allclose(cell.c.value, c_before, atol=1e-8))

    def test_forget_gate_bias(self):
        """Test that a +1 offset biases the forget gate towards remembering.

        The forget-gate pre-activation is recomputed from ``cell.W`` on a zero
        input, and the test adds the +1 offset itself. That offset presumably
        mirrors the one applied inside ``LSTMCell.update`` (NOTE(review):
        confirm). So this checks the offset's effect, not a bias stored in the cell.
        """
        cell = LSTMCell(num_in=self.num_in, num_out=self.num_out)
        cell.init_state(batch_size=self.batch_size)

        # Use zero input so the pre-activation comes only from the initial hidden
        # state and the layer's bias.
        zero_input = jnp.zeros((self.batch_size, self.num_in))
        xh = jnp.concatenate([zero_input, cell.h.value], axis=-1)
        gates = cell.W(xh)
        _, _, f, _ = jnp.split(gates, indices_or_sections=4, axis=-1)
        f_gate = functional.sigmoid(f + 1.)

        # With the +1 offset, the forget gate should lean towards remembering (> 0.5).
        self.assertTrue(jnp.mean(f_gate) > 0.5)

    def test_gate_values_range(self):
        """Test that the input, forget and output gates produce values in [0, 1]."""
        cell = LSTMCell(num_in=self.num_in, num_out=self.num_out)
        cell.init_state(batch_size=self.batch_size)

        xh = jnp.concatenate([self.x, cell.h.value], axis=-1)
        # Split order (i, g, f, o); the candidate `g` is not a gate and is ignored.
        i, _, f, o = jnp.split(cell.W(xh), indices_or_sections=4, axis=-1)

        named_gates = {
            'input': functional.sigmoid(i),
            'forget': functional.sigmoid(f + 1.),
            'output': functional.sigmoid(o),
        }

        for name, gate in named_gates.items():
            with self.subTest(gate=name):
                self.assertTrue(jnp.all(gate >= 0))
                self.assertTrue(jnp.all(gate <= 1))

    def test_cell_state_gradient_flow(self):
        """Test gradient flow through the cell state over 10 repeated updates."""
        cell = LSTMCell(num_in=self.num_in, num_out=self.num_out)
        cell.init_state(batch_size=self.batch_size)

        def loss_fn(x):
            # Feed the same input 10 times, then take the loss on the cell state.
            for _ in range(10):
                cell.update(x)
            return jnp.mean(cell.c.value ** 2)

        grad_fn = jax.grad(loss_fn)
        grad = grad_fn(self.x)

        self.assertFalse(jnp.any(jnp.isnan(grad)))
        self.assertTrue(jnp.any(grad != 0))
346
+
347
+
348
class TestURLSTMCell(TestRNNCellBase):
    """Test suite for URLSTMCell: forward pass, bias setup, gating and LSTM comparison."""

    def test_basic_forward(self):
        """Run one step and check the output shape and absence of NaNs."""
        rnn = URLSTMCell(num_in=self.num_in, num_out=self.num_out)
        rnn.init_state(batch_size=self.batch_size)

        result = rnn.update(self.x)

        self.assertEqual(result.shape, (self.batch_size, self.num_out))
        self.assertFalse(jnp.any(jnp.isnan(result)))

    def test_untied_bias_mechanism(self):
        """The cell's bias vector exists, has one entry per unit, and is spread out."""
        rnn = URLSTMCell(num_in=self.num_in, num_out=self.num_out)
        rnn.init_state(batch_size=self.batch_size)

        bias = rnn.bias.value
        self.assertIsNotNone(bias)
        self.assertEqual(bias.shape, (self.num_out,))
        # A standard deviation above 0.1 means the entries are not all equal.
        self.assertGreater(jnp.std(bias), 0.1)

    def test_unified_gate_computation(self):
        """Recompute the unified gate by hand and check it stays within [0, 1]."""
        rnn = URLSTMCell(num_in=self.num_in, num_out=self.num_out)
        rnn.init_state(batch_size=self.batch_size)

        hidden, memory = rnn.h.value, rnn.c.value
        pre = rnn.W(jnp.concatenate([self.x, hidden], axis=-1))
        f_pre, r_pre, _, _ = jnp.split(pre, indices_or_sections=4, axis=-1)

        # The bias is added to the forget pre-activation and subtracted from the refine one.
        forget = functional.sigmoid(f_pre + rnn.bias.value)
        refine = functional.sigmoid(r_pre - rnn.bias.value)

        unified = 2 * refine * forget + (1 - 2 * refine) * forget ** 2

        self.assertTrue(jnp.all(unified >= 0))
        self.assertTrue(jnp.all(unified <= 1))

    def test_comparison_with_lstm(self):
        """From the same constant initial state, URLSTM and LSTM must give different outputs."""
        ur = URLSTMCell(num_in=self.num_in, num_out=self.num_out, state_init=init.Constant(0.5))
        plain = LSTMCell(num_in=self.num_in, num_out=self.num_out, state_init=init.Constant(0.5))

        ur.init_state(batch_size=self.batch_size)
        plain.init_state(batch_size=self.batch_size)

        ur_result = ur.update(self.x)
        plain_result = plain.update(self.x)

        self.assertFalse(jnp.allclose(ur_result, plain_result, atol=1e-4))
406
+
407
+
408
class TestRNNCellIntegration(TestRNNCellBase):
    """Integration tests for all RNN cells."""

    def test_all_cells_compatible_interface(self):
        """Test that all cells share the init_state / update / reset_state interface."""
        cell_types = [ValinaRNNCell, GRUCell, MGUCell, LSTMCell, URLSTMCell]

        for CellType in cell_types:
            with self.subTest(cell=CellType.__name__):
                cell = CellType(num_in=self.num_in, num_out=self.num_out)

                # init_state followed by a single update.
                cell.init_state(batch_size=self.batch_size)
                output = cell.update(self.x)
                self.assertEqual(output.shape, (self.batch_size, self.num_out))

                # reset_state with a different batch size, then update again.
                cell.reset_state(batch_size=16)
                x_small = jnp.ones((16, self.num_in))
                output_small = cell.update(x_small)
                self.assertEqual(output_small.shape, (16, self.num_out))

    def test_sequence_to_sequence(self):
        """Test sequence-to-sequence processing for every cell type."""
        cell_types = [ValinaRNNCell, GRUCell, MGUCell, LSTMCell, URLSTMCell]

        for CellType in cell_types:
            with self.subTest(cell=CellType.__name__):
                cell = CellType(num_in=self.num_in, num_out=self.num_out)
                cell.init_state(batch_size=self.batch_size)

                outputs = jnp.stack(
                    [cell.update(self.sequence[t]) for t in range(self.sequence_length)]
                )
                self.assertEqual(
                    outputs.shape,
                    (self.sequence_length, self.batch_size, self.num_out),
                )

    def test_variable_length_sequences(self):
        """Test handling of variable length sequences with masking.

        Each batch element is given a length taken cyclically from
        (10, 20, 30, 40). Outputs at time steps past an element's length are
        zeroed by the mask and must stay zero.
        """
        cell = LSTMCell(num_in=self.num_in, num_out=self.num_out)
        cell.init_state(batch_size=self.batch_size)

        # Assign lengths cyclically so there is exactly one length per batch
        # element. Repeating the 4-item list batch_size // 4 times falls short
        # whenever batch_size is not a multiple of 4.
        length_pattern = jnp.array([10, 20, 30, 40])
        lengths = length_pattern[jnp.arange(self.batch_size) % length_pattern.size]
        # mask[t, b] is True while time step t lies within element b's length.
        mask = jnp.arange(self.sequence_length)[:, None] < lengths[None, :]

        outputs = []
        for t in range(self.sequence_length):
            output = cell.update(self.sequence[t])
            # mask[t][:, None] has shape (batch, 1) and broadcasts over features.
            outputs.append(output * mask[t][:, None])

        outputs = jnp.stack(outputs)

        # Check that masked positions are zero.
        for b in range(self.batch_size):
            length = int(lengths[b])
            if length < self.sequence_length:
                with self.subTest(batch_index=b, length=length):
                    self.assertTrue(jnp.allclose(outputs[length:, b, :], 0.0))

    def test_gradient_clipping(self):
        """Test that gradients stay finite and bounded over a long, amplified sequence.

        No clipping is applied here. The test only checks that gradients
        w.r.t. the input do not become NaN/inf or exceed 1e6 in magnitude after
        50 GRU steps with inputs scaled by 1..50.
        """
        cell = GRUCell(num_in=self.num_in, num_out=self.num_out)
        cell.init_state(batch_size=self.batch_size)

        def loss_fn(x):
            output = jnp.zeros((self.batch_size, self.num_out))
            for t in range(50):  # Long sequence
                output = cell.update(x * (t + 1))  # Amplify input
            return jnp.mean(output ** 2)

        grad_fn = jax.grad(loss_fn)
        grad = grad_fn(self.x)

        # Gradients should not explode.
        self.assertFalse(jnp.any(jnp.isnan(grad)))
        self.assertFalse(jnp.any(jnp.isinf(grad)))
        self.assertLess(jnp.max(jnp.abs(grad)), 1e6)
494
+
495
+
496
class TestRNNCellEdgeCases(TestRNNCellBase):
    """Edge case tests for RNN cells."""

    def test_single_sample(self):
        """Test with batch size of 1."""
        cell = ValinaRNNCell(num_in=self.num_in, num_out=self.num_out)
        cell.init_state(batch_size=1)

        x = jnp.ones((1, self.num_in))
        output = cell.update(x)
        self.assertEqual(output.shape, (1, self.num_out))

    def test_zero_input(self):
        """Test every cell type with an all-zero input (one subTest per cell)."""
        cell_types = [ValinaRNNCell, GRUCell, MGUCell, LSTMCell, URLSTMCell]

        for CellType in cell_types:
            with self.subTest(cell=CellType.__name__):
                cell = CellType(num_in=self.num_in, num_out=self.num_out)
                cell.init_state(batch_size=self.batch_size)

                zero_input = jnp.zeros((self.batch_size, self.num_in))
                output = cell.update(zero_input)

                self.assertEqual(output.shape, (self.batch_size, self.num_out))
                self.assertFalse(jnp.any(jnp.isnan(output)))

    def test_large_input_values(self):
        """Test with large input values (100 in every component)."""
        cell = LSTMCell(num_in=self.num_in, num_out=self.num_out)
        cell.init_state(batch_size=self.batch_size)

        large_input = jnp.ones((self.batch_size, self.num_in)) * 100
        output = cell.update(large_input)

        # Large inputs should saturate the gates rather than produce NaN/inf.
        self.assertFalse(jnp.any(jnp.isnan(output)))
        self.assertFalse(jnp.any(jnp.isinf(output)))

    def test_very_long_sequence(self):
        """Test a GRU over 1000 steps of constant input."""
        steps, batch = 1000, 4  # Small batch keeps memory low.
        cell = GRUCell(num_in=self.num_in, num_out=self.num_out)
        cell.init_state(batch_size=batch)

        long_sequence = jnp.ones((steps, batch, self.num_in))

        final_output = None
        for t in range(steps):
            final_output = cell.update(long_sequence[t])

        # No numerical issues even after a long sequence.
        self.assertFalse(jnp.any(jnp.isnan(final_output)))
        self.assertFalse(jnp.any(jnp.isinf(final_output)))

    def test_dimension_mismatch_error(self):
        """Test that an input with the wrong feature dimension raises an error."""
        cell = ValinaRNNCell(num_in=self.num_in, num_out=self.num_out)
        cell.init_state(batch_size=self.batch_size)

        wrong_input = jnp.ones((self.batch_size, self.num_in + 5))

        # The concrete exception type is not pinned down here, so any Exception is accepted.
        with self.assertRaises(Exception):
            cell.update(wrong_input)
559
+
560
+
561
class TestRNNCellProperties(TestRNNCellBase):
    """Test cell properties and attributes."""

    def test_cell_attributes(self):
        """Test that cells expose their input/output sizes."""
        cell = LSTMCell(num_in=self.num_in, num_out=self.num_out)

        self.assertEqual(cell.num_in, self.num_in)
        self.assertEqual(cell.num_out, self.num_out)
        self.assertEqual(cell.in_size, (self.num_in,))
        self.assertEqual(cell.out_size, (self.num_out,))

    def test_inheritance_structure(self):
        """Test that all cells inherit from RNNCell (one subTest per cell)."""
        cell_types = [ValinaRNNCell, GRUCell, MGUCell, LSTMCell, URLSTMCell]

        for CellType in cell_types:
            with self.subTest(cell=CellType.__name__):
                cell = CellType(num_in=self.num_in, num_out=self.num_out)
                self.assertIsInstance(cell, RNNCell)

    def test_docstring_presence(self):
        """Test that all cells have docstrings with Parameters and Examples sections."""
        cell_types = [ValinaRNNCell, GRUCell, MGUCell, LSTMCell, URLSTMCell]

        for CellType in cell_types:
            with self.subTest(cell=CellType.__name__):
                self.assertIsNotNone(CellType.__doc__)
                self.assertIn("Examples", CellType.__doc__)
                self.assertIn("Parameters", CellType.__doc__)
                self.assertIn(">>>", CellType.__doc__)
590
+
591
+
592
if __name__ == '__main__':
    # When executed as a script, discover and run every TestCase in this module.
    unittest.main()