brainstate-0.1.10-py2.py3-none-any.whl → brainstate-0.2.0-py2.py3-none-any.whl

This diff compares the contents of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between those versions as published.
Files changed (163)
  1. brainstate/__init__.py +130 -19
  2. brainstate/_compatible_import.py +201 -9
  3. brainstate/_compatible_import_test.py +681 -0
  4. brainstate/_deprecation.py +210 -0
  5. brainstate/_deprecation_test.py +2319 -0
  6. brainstate/{util/error.py → _error.py} +10 -20
  7. brainstate/_state.py +94 -47
  8. brainstate/_state_test.py +1 -1
  9. brainstate/_utils.py +1 -1
  10. brainstate/environ.py +1279 -347
  11. brainstate/environ_test.py +1187 -26
  12. brainstate/graph/__init__.py +6 -13
  13. brainstate/graph/_node.py +240 -0
  14. brainstate/graph/_node_test.py +589 -0
  15. brainstate/graph/{_graph_operation.py → _operation.py} +632 -746
  16. brainstate/graph/_operation_test.py +1147 -0
  17. brainstate/mixin.py +1209 -141
  18. brainstate/mixin_test.py +991 -51
  19. brainstate/nn/__init__.py +74 -72
  20. brainstate/nn/_activations.py +587 -295
  21. brainstate/nn/_activations_test.py +109 -86
  22. brainstate/nn/_collective_ops.py +393 -274
  23. brainstate/nn/_collective_ops_test.py +746 -15
  24. brainstate/nn/_common.py +114 -66
  25. brainstate/nn/_common_test.py +154 -0
  26. brainstate/nn/_conv.py +1652 -143
  27. brainstate/nn/_conv_test.py +838 -227
  28. brainstate/nn/_delay.py +15 -28
  29. brainstate/nn/_delay_test.py +25 -20
  30. brainstate/nn/_dropout.py +359 -167
  31. brainstate/nn/_dropout_test.py +429 -52
  32. brainstate/nn/_dynamics.py +14 -90
  33. brainstate/nn/_dynamics_test.py +1 -12
  34. brainstate/nn/_elementwise.py +492 -313
  35. brainstate/nn/_elementwise_test.py +806 -145
  36. brainstate/nn/_embedding.py +369 -19
  37. brainstate/nn/_embedding_test.py +156 -0
  38. brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +10 -16
  39. brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +6 -5
  40. brainstate/nn/{_linear_mv.py → _event_linear.py} +2 -2
  41. brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +6 -5
  42. brainstate/nn/_exp_euler.py +200 -38
  43. brainstate/nn/_exp_euler_test.py +350 -8
  44. brainstate/nn/_linear.py +391 -71
  45. brainstate/nn/_linear_test.py +427 -59
  46. brainstate/nn/_metrics.py +1070 -0
  47. brainstate/nn/_metrics_test.py +611 -0
  48. brainstate/nn/_module.py +10 -3
  49. brainstate/nn/_module_test.py +1 -1
  50. brainstate/nn/_normalizations.py +688 -329
  51. brainstate/nn/_normalizations_test.py +663 -37
  52. brainstate/nn/_paddings.py +1020 -0
  53. brainstate/nn/_paddings_test.py +723 -0
  54. brainstate/nn/_poolings.py +1404 -342
  55. brainstate/nn/_poolings_test.py +828 -92
  56. brainstate/nn/{_rate_rnns.py → _rnns.py} +446 -54
  57. brainstate/nn/_rnns_test.py +593 -0
  58. brainstate/nn/_utils.py +132 -5
  59. brainstate/nn/_utils_test.py +402 -0
  60. brainstate/{init/_random_inits.py → nn/init.py} +301 -45
  61. brainstate/{init/_random_inits_test.py → nn/init_test.py} +51 -20
  62. brainstate/random/__init__.py +247 -1
  63. brainstate/random/_rand_funs.py +668 -346
  64. brainstate/random/_rand_funs_test.py +74 -1
  65. brainstate/random/_rand_seed.py +541 -76
  66. brainstate/random/_rand_seed_test.py +1 -1
  67. brainstate/random/_rand_state.py +601 -393
  68. brainstate/random/_rand_state_test.py +551 -0
  69. brainstate/transform/__init__.py +59 -0
  70. brainstate/transform/_ad_checkpoint.py +176 -0
  71. brainstate/{compile → transform}/_ad_checkpoint_test.py +1 -1
  72. brainstate/{augment → transform}/_autograd.py +360 -113
  73. brainstate/{augment → transform}/_autograd_test.py +2 -2
  74. brainstate/transform/_conditions.py +316 -0
  75. brainstate/{compile → transform}/_conditions_test.py +11 -11
  76. brainstate/{compile → transform}/_error_if.py +22 -20
  77. brainstate/{compile → transform}/_error_if_test.py +1 -1
  78. brainstate/transform/_eval_shape.py +145 -0
  79. brainstate/{augment → transform}/_eval_shape_test.py +1 -1
  80. brainstate/{compile → transform}/_jit.py +99 -46
  81. brainstate/{compile → transform}/_jit_test.py +3 -3
  82. brainstate/{compile → transform}/_loop_collect_return.py +219 -80
  83. brainstate/{compile → transform}/_loop_collect_return_test.py +1 -1
  84. brainstate/{compile → transform}/_loop_no_collection.py +133 -34
  85. brainstate/{compile → transform}/_loop_no_collection_test.py +2 -2
  86. brainstate/transform/_make_jaxpr.py +2016 -0
  87. brainstate/transform/_make_jaxpr_test.py +1510 -0
  88. brainstate/transform/_mapping.py +529 -0
  89. brainstate/transform/_mapping_test.py +194 -0
  90. brainstate/{compile → transform}/_progress_bar.py +78 -25
  91. brainstate/{augment → transform}/_random.py +65 -45
  92. brainstate/{compile → transform}/_unvmap.py +102 -5
  93. brainstate/transform/_util.py +286 -0
  94. brainstate/typing.py +594 -61
  95. brainstate/typing_test.py +780 -0
  96. brainstate/util/__init__.py +9 -32
  97. brainstate/util/_others.py +1025 -0
  98. brainstate/util/_others_test.py +962 -0
  99. brainstate/util/_pretty_pytree.py +1301 -0
  100. brainstate/util/_pretty_pytree_test.py +675 -0
  101. brainstate/util/{pretty_repr.py → _pretty_repr.py} +161 -27
  102. brainstate/util/_pretty_repr_test.py +696 -0
  103. brainstate/util/filter.py +557 -81
  104. brainstate/util/filter_test.py +912 -0
  105. brainstate/util/struct.py +769 -382
  106. brainstate/util/struct_test.py +602 -0
  107. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/METADATA +34 -17
  108. brainstate-0.2.0.dist-info/RECORD +111 -0
  109. brainstate/augment/__init__.py +0 -30
  110. brainstate/augment/_eval_shape.py +0 -99
  111. brainstate/augment/_mapping.py +0 -1060
  112. brainstate/augment/_mapping_test.py +0 -597
  113. brainstate/compile/__init__.py +0 -38
  114. brainstate/compile/_ad_checkpoint.py +0 -204
  115. brainstate/compile/_conditions.py +0 -256
  116. brainstate/compile/_make_jaxpr.py +0 -888
  117. brainstate/compile/_make_jaxpr_test.py +0 -156
  118. brainstate/compile/_util.py +0 -147
  119. brainstate/functional/__init__.py +0 -27
  120. brainstate/graph/_graph_node.py +0 -244
  121. brainstate/graph/_graph_node_test.py +0 -73
  122. brainstate/graph/_graph_operation_test.py +0 -563
  123. brainstate/init/__init__.py +0 -26
  124. brainstate/init/_base.py +0 -52
  125. brainstate/init/_generic.py +0 -244
  126. brainstate/init/_regular_inits.py +0 -105
  127. brainstate/init/_regular_inits_test.py +0 -50
  128. brainstate/nn/_inputs.py +0 -608
  129. brainstate/nn/_ltp.py +0 -28
  130. brainstate/nn/_neuron.py +0 -705
  131. brainstate/nn/_neuron_test.py +0 -161
  132. brainstate/nn/_others.py +0 -46
  133. brainstate/nn/_projection.py +0 -486
  134. brainstate/nn/_rate_rnns_test.py +0 -63
  135. brainstate/nn/_readout.py +0 -209
  136. brainstate/nn/_readout_test.py +0 -53
  137. brainstate/nn/_stp.py +0 -236
  138. brainstate/nn/_synapse.py +0 -505
  139. brainstate/nn/_synapse_test.py +0 -131
  140. brainstate/nn/_synaptic_projection.py +0 -423
  141. brainstate/nn/_synouts.py +0 -162
  142. brainstate/nn/_synouts_test.py +0 -57
  143. brainstate/nn/metrics.py +0 -388
  144. brainstate/optim/__init__.py +0 -38
  145. brainstate/optim/_base.py +0 -64
  146. brainstate/optim/_lr_scheduler.py +0 -448
  147. brainstate/optim/_lr_scheduler_test.py +0 -50
  148. brainstate/optim/_optax_optimizer.py +0 -152
  149. brainstate/optim/_optax_optimizer_test.py +0 -53
  150. brainstate/optim/_sgd_optimizer.py +0 -1104
  151. brainstate/random/_random_for_unit.py +0 -52
  152. brainstate/surrogate.py +0 -1957
  153. brainstate/transform.py +0 -23
  154. brainstate/util/caller.py +0 -98
  155. brainstate/util/others.py +0 -540
  156. brainstate/util/pretty_pytree.py +0 -945
  157. brainstate/util/pretty_pytree_test.py +0 -159
  158. brainstate/util/pretty_table.py +0 -2954
  159. brainstate/util/scaling.py +0 -258
  160. brainstate-0.1.10.dist-info/RECORD +0 -130
  161. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/WHEEL +0 -0
  162. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/licenses/LICENSE +0 -0
  163. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/top_level.txt +0 -0
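Most of the churn above is a package reorganization rather than new functionality: the former brainstate.compile and brainstate.augment modules are merged into brainstate.transform, the brainstate.init initializers move under brainstate.nn.init, the optim, functional and surrogate modules are removed, and several nn submodules are renamed (e.g. _rate_rnns.py → _rnns.py, _fixedprob.py → _event_fixedprob.py). A minimal migration sketch follows; the public re-exports are inferred from the file moves above and are assumptions, not verified against the released wheel:

    # Hedged migration sketch (0.1.10 -> 0.2.0); names inferred from the renames above.
    import brainstate
    from brainstate import transform   # assumed successor of brainstate.compile / brainstate.augment
    from brainstate.nn import init     # assumed successor of brainstate.init

    @transform.jit                     # assumed re-export of the former brainstate.compile.jit
    def double(x):
        return x * 2.0

    w_init = init.Orthogonal()         # initializers now live under brainstate.nn.init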
brainstate/nn/_rnns_test.py (new file)
@@ -0,0 +1,593 @@
+ # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ """Comprehensive tests for RNN cell implementations."""
+
+ import unittest
+ from typing import Type
+
+ import jax
+ import jax.numpy as jnp
+ import numpy as np
+
+ import brainstate
+ import brainstate.nn as nn
+ from brainstate.nn import RNNCell, ValinaRNNCell, GRUCell, MGUCell, LSTMCell, URLSTMCell
+ from brainstate.nn import init as init
+ from brainstate.nn import _activations as functional
+
+
+ class TestRNNCellBase(unittest.TestCase):
+     """Base test class for all RNN cell implementations."""
+
+     def setUp(self):
+         """Set up test fixtures."""
+         self.num_in = 10
+         self.num_out = 20
+         self.batch_size = 32
+         self.sequence_length = 100
+         self.seed = 42
+
+         # Initialize random inputs
+         key = jax.random.PRNGKey(self.seed)
+         self.x = jax.random.normal(key, (self.batch_size, self.num_in))
+         self.sequence = jax.random.normal(
+             key, (self.sequence_length, self.batch_size, self.num_in)
+         )
+
+
+ class TestVanillaRNNCell(TestRNNCellBase):
+     """Comprehensive tests for VanillaRNNCell."""
+
+     def test_basic_forward(self):
+         """Test basic forward pass."""
+         cell = ValinaRNNCell(num_in=self.num_in, num_out=self.num_out)
+         cell.init_state(batch_size=self.batch_size)
+
+         output = cell.update(self.x)
+
+         self.assertEqual(output.shape, (self.batch_size, self.num_out))
+         self.assertFalse(jnp.any(jnp.isnan(output)))
+         self.assertFalse(jnp.any(jnp.isinf(output)))
+
+     def test_sequence_processing(self):
+         """Test processing a sequence of inputs."""
+         cell = ValinaRNNCell(num_in=self.num_in, num_out=self.num_out)
+         cell.init_state(batch_size=self.batch_size)
+
+         outputs = []
+         for t in range(self.sequence_length):
+             output = cell.update(self.sequence[t])
+             outputs.append(output)
+
+         outputs = jnp.stack(outputs)
+         self.assertEqual(outputs.shape, (self.sequence_length, self.batch_size, self.num_out))
+         self.assertFalse(jnp.any(jnp.isnan(outputs)))
+
+     def test_state_reset(self):
+         """Test state reset functionality."""
+         cell = ValinaRNNCell(num_in=self.num_in, num_out=self.num_out)
+         cell.init_state(batch_size=self.batch_size)
+
+         # Process some input
+         _ = cell.update(self.x)
+         state_before = cell.h.value.copy()
+
+         # Reset state
+         cell.reset_state(batch_size=self.batch_size)
+         state_after = cell.h.value.copy()
+
+         # States should be different (unless randomly the same, which is unlikely)
+         self.assertFalse(jnp.allclose(state_before, state_after, atol=1e-6))
+
+     def test_different_batch_sizes(self):
+         """Test with different batch sizes."""
+         cell = ValinaRNNCell(num_in=self.num_in, num_out=self.num_out)
+
+         for batch_size in [1, 16, 64]:
+             cell.init_state(batch_size=batch_size)
+             x = jnp.ones((batch_size, self.num_in))
+             output = cell.update(x)
+             self.assertEqual(output.shape, (batch_size, self.num_out))
+
+     def test_activation_functions(self):
+         """Test different activation functions."""
+         activations = ['relu', 'tanh', 'sigmoid', 'gelu']
+
+         for activation in activations:
+             cell = ValinaRNNCell(
+                 num_in=self.num_in,
+                 num_out=self.num_out,
+                 activation=activation
+             )
+             cell.init_state(batch_size=self.batch_size)
+             output = cell.update(self.x)
+             self.assertEqual(output.shape, (self.batch_size, self.num_out))
+             self.assertFalse(jnp.any(jnp.isnan(output)))
+
+     def test_custom_initializers(self):
+         """Test custom weight and state initializers."""
+         cell = ValinaRNNCell(
+             num_in=self.num_in,
+             num_out=self.num_out,
+             w_init=init.Orthogonal(),
+             b_init=init.Constant(0.1),
+             state_init=init.Normal(0.01)
+         )
+         cell.init_state(batch_size=self.batch_size)
+         output = cell.update(self.x)
+         self.assertEqual(output.shape, (self.batch_size, self.num_out))
+
+     def test_gradient_flow(self):
+         """Test gradient flow through the cell."""
+         cell = ValinaRNNCell(num_in=self.num_in, num_out=self.num_out)
+         cell.init_state(batch_size=self.batch_size)
+
+         def loss_fn(x):
+             output = cell.update(x)
+             return jnp.mean(output ** 2)
+
+         grad_fn = jax.grad(loss_fn)
+         grad = grad_fn(self.x)
+
+         self.assertEqual(grad.shape, self.x.shape)
+         self.assertFalse(jnp.any(jnp.isnan(grad)))
+         self.assertTrue(jnp.any(grad != 0))  # Gradients should be non-zero
+
+
+ class TestGRUCell(TestRNNCellBase):
+     """Comprehensive tests for GRUCell."""
+
+     def test_basic_forward(self):
+         """Test basic forward pass."""
+         cell = GRUCell(num_in=self.num_in, num_out=self.num_out)
+         cell.init_state(batch_size=self.batch_size)
+
+         output = cell.update(self.x)
+
+         self.assertEqual(output.shape, (self.batch_size, self.num_out))
+         self.assertFalse(jnp.any(jnp.isnan(output)))
+
+     def test_gating_mechanism(self):
+         """Test that gating values are in valid range."""
+         cell = GRUCell(num_in=self.num_in, num_out=self.num_out)
+         cell.init_state(batch_size=self.batch_size)
+
+         # Access internal computation
+         old_h = cell.h.value
+         xh = jnp.concatenate([self.x, old_h], axis=-1)
+         gates = functional.sigmoid(cell.Wrz(xh))
+
+         # Gates should be between 0 and 1
+         self.assertTrue(jnp.all(gates >= 0))
+         self.assertTrue(jnp.all(gates <= 1))
+
+     def test_state_persistence(self):
+         """Test that state persists across updates."""
+         cell = GRUCell(num_in=self.num_in, num_out=self.num_out)
+         cell.init_state(batch_size=self.batch_size)
+
+         # Process sequence and track states
+         states = []
+         for t in range(10):
+             _ = cell.update(self.sequence[t])
+             states.append(cell.h.value.copy())
+
+         # States should evolve over time
+         for i in range(1, len(states)):
+             self.assertFalse(jnp.allclose(states[i], states[i-1], atol=1e-8))
+
+     def test_reset_vs_update_gates(self):
+         """Test that reset and update gates behave differently."""
+         cell = GRUCell(num_in=self.num_in, num_out=self.num_out)
+         cell.init_state(batch_size=self.batch_size)
+
+         # Get gates for the same input
+         old_h = cell.h.value
+         xh = jnp.concatenate([self.x, old_h], axis=-1)
+         r, z = jnp.split(functional.sigmoid(cell.Wrz(xh)), indices_or_sections=2, axis=-1)
+
+         # Reset and update gates should be different
+         self.assertFalse(jnp.allclose(r, z, atol=1e-6))
+
+     def test_different_initializers(self):
+         """Test with different weight initializers."""
+         initializers = [
+             init.XavierNormal(),
+             init.XavierUniform(),
+             init.Orthogonal(),
+             init.KaimingNormal(),
+         ]
+
+         for w_init in initializers:
+             cell = GRUCell(
+                 num_in=self.num_in,
+                 num_out=self.num_out,
+                 w_init=w_init
+             )
+             cell.init_state(batch_size=self.batch_size)
+             output = cell.update(self.x)
+             self.assertEqual(output.shape, (self.batch_size, self.num_out))
+
+
+ class TestMGUCell(TestRNNCellBase):
+     """Comprehensive tests for MGUCell."""
+
+     def test_basic_forward(self):
+         """Test basic forward pass."""
+         cell = MGUCell(num_in=self.num_in, num_out=self.num_out)
+         cell.init_state(batch_size=self.batch_size)
+
+         output = cell.update(self.x)
+
+         self.assertEqual(output.shape, (self.batch_size, self.num_out))
+         self.assertFalse(jnp.any(jnp.isnan(output)))
+
+     def test_single_gate_mechanism(self):
+         """Test that MGU uses single forget gate."""
+         cell = MGUCell(num_in=self.num_in, num_out=self.num_out)
+         cell.init_state(batch_size=self.batch_size)
+
+         # Check that only one gate is computed
+         xh = jnp.concatenate([self.x, cell.h.value], axis=-1)
+         f = functional.sigmoid(cell.Wf(xh))
+
+         # Forget gate should be between 0 and 1
+         self.assertTrue(jnp.all(f >= 0))
+         self.assertTrue(jnp.all(f <= 1))
+         self.assertEqual(f.shape, (self.batch_size, self.num_out))
+
+     def test_parameter_efficiency(self):
+         """Test that MGU has fewer parameters than GRU."""
+         mgu_cell = MGUCell(num_in=self.num_in, num_out=self.num_out)
+         gru_cell = GRUCell(num_in=self.num_in, num_out=self.num_out)
+
+         # Count parameters - MGU should have fewer
+         # MGU has 2 weight matrices (Wf, Wh)
+         # GRU has 2 weight matrices but one is double size (Wrz, Wh)
+         mgu_param_count = 2 * ((self.num_in + self.num_out) * self.num_out + self.num_out)
+         gru_param_count = ((self.num_in + self.num_out) * (self.num_out * 2) + self.num_out * 2) + \
+                           ((self.num_in + self.num_out) * self.num_out + self.num_out)
+
+         self.assertLess(mgu_param_count, gru_param_count)
+
+
+ class TestLSTMCell(TestRNNCellBase):
+     """Comprehensive tests for LSTMCell."""
+
+     def test_basic_forward(self):
+         """Test basic forward pass."""
+         cell = LSTMCell(num_in=self.num_in, num_out=self.num_out)
+         cell.init_state(batch_size=self.batch_size)
+
+         output = cell.update(self.x)
+
+         self.assertEqual(output.shape, (self.batch_size, self.num_out))
+         self.assertFalse(jnp.any(jnp.isnan(output)))
+
+     def test_dual_state_mechanism(self):
+         """Test that LSTM maintains both hidden and cell states."""
+         cell = LSTMCell(num_in=self.num_in, num_out=self.num_out)
+         cell.init_state(batch_size=self.batch_size)
+
+         # Check initial states
+         self.assertIsNotNone(cell.h)
+         self.assertIsNotNone(cell.c)
+         self.assertEqual(cell.h.value.shape, (self.batch_size, self.num_out))
+         self.assertEqual(cell.c.value.shape, (self.batch_size, self.num_out))
+
+         # Update and check states change
+         h_before = cell.h.value.copy()
+         c_before = cell.c.value.copy()
+
+         _ = cell.update(self.x)
+
+         self.assertFalse(jnp.allclose(cell.h.value, h_before, atol=1e-8))
+         self.assertFalse(jnp.allclose(cell.c.value, c_before, atol=1e-8))
+
+     def test_forget_gate_bias(self):
+         """Test that forget gate has positive bias initialization."""
+         cell = LSTMCell(num_in=self.num_in, num_out=self.num_out)
+         cell.init_state(batch_size=self.batch_size)
+
+         # Process with zero input to see bias effect
+         zero_input = jnp.zeros((self.batch_size, self.num_in))
+         xh = jnp.concatenate([zero_input, cell.h.value], axis=-1)
+         gates = cell.W(xh)
+         _, _, f, _ = jnp.split(gates, indices_or_sections=4, axis=-1)
+         f_gate = functional.sigmoid(f + 1.)  # Note the +1 bias
+
+         # Forget gate should be biased towards remembering (> 0.5)
+         self.assertTrue(jnp.mean(f_gate) > 0.5)
+
+     def test_gate_values_range(self):
+         """Test that all gates produce values in [0, 1]."""
+         cell = LSTMCell(num_in=self.num_in, num_out=self.num_out)
+         cell.init_state(batch_size=self.batch_size)
+
+         xh = jnp.concatenate([self.x, cell.h.value], axis=-1)
+         i, g, f, o = jnp.split(cell.W(xh), indices_or_sections=4, axis=-1)
+
+         i_gate = functional.sigmoid(i)
+         f_gate = functional.sigmoid(f + 1.)
+         o_gate = functional.sigmoid(o)
+
+         for gate in [i_gate, f_gate, o_gate]:
+             self.assertTrue(jnp.all(gate >= 0))
+             self.assertTrue(jnp.all(gate <= 1))
+
+     def test_cell_state_gradient_flow(self):
+         """Test gradient flow through cell state."""
+         cell = LSTMCell(num_in=self.num_in, num_out=self.num_out)
+         cell.init_state(batch_size=self.batch_size)
+
+         def loss_fn(x):
+             for t in range(10):
+                 _ = cell.update(x)
+             return jnp.mean(cell.c.value ** 2)
+
+         grad_fn = jax.grad(loss_fn)
+         grad = grad_fn(self.x)
+
+         self.assertFalse(jnp.any(jnp.isnan(grad)))
+         self.assertTrue(jnp.any(grad != 0))
+
+
+ class TestURLSTMCell(TestRNNCellBase):
+     """Comprehensive tests for URLSTMCell."""
+
+     def test_basic_forward(self):
+         """Test basic forward pass."""
+         cell = URLSTMCell(num_in=self.num_in, num_out=self.num_out)
+         cell.init_state(batch_size=self.batch_size)
+
+         output = cell.update(self.x)
+
+         self.assertEqual(output.shape, (self.batch_size, self.num_out))
+         self.assertFalse(jnp.any(jnp.isnan(output)))
+
+     def test_untied_bias_mechanism(self):
+         """Test the untied bias initialization."""
+         cell = URLSTMCell(num_in=self.num_in, num_out=self.num_out)
+         cell.init_state(batch_size=self.batch_size)
+
+         # Check bias values are initialized
+         self.assertIsNotNone(cell.bias.value)
+         self.assertEqual(cell.bias.value.shape, (self.num_out,))
+
+         # Biases should be diverse (not all the same)
+         self.assertGreater(jnp.std(cell.bias.value), 0.1)
+
+     def test_unified_gate_computation(self):
+         """Test the unified gate mechanism."""
+         cell = URLSTMCell(num_in=self.num_in, num_out=self.num_out)
+         cell.init_state(batch_size=self.batch_size)
+
+         h, c = cell.h.value, cell.c.value
+         xh = jnp.concatenate([self.x, h], axis=-1)
+         gates = cell.W(xh)
+         f, r, u, o = jnp.split(gates, indices_or_sections=4, axis=-1)
+
+         f_gate = functional.sigmoid(f + cell.bias.value)
+         r_gate = functional.sigmoid(r - cell.bias.value)
+
+         # Compute unified gate
+         g = 2 * r_gate * f_gate + (1 - 2 * r_gate) * f_gate ** 2
+
+         # Unified gate should be in [0, 1]
+         self.assertTrue(jnp.all(g >= 0))
+         self.assertTrue(jnp.all(g <= 1))
+
+     def test_comparison_with_lstm(self):
+         """Test that URLSTM behaves differently from standard LSTM."""
+         urlstm = URLSTMCell(num_in=self.num_in, num_out=self.num_out, state_init=init.Constant(0.5))
+         lstm = LSTMCell(num_in=self.num_in, num_out=self.num_out, state_init=init.Constant(0.5))
+
+         urlstm.init_state(batch_size=self.batch_size)
+         lstm.init_state(batch_size=self.batch_size)
+
+         # Same input should produce different outputs
+         urlstm_out = urlstm.update(self.x)
+         lstm_out = lstm.update(self.x)
+
+         self.assertFalse(jnp.allclose(urlstm_out, lstm_out, atol=1e-4))
+
+
+ class TestRNNCellIntegration(TestRNNCellBase):
+     """Integration tests for all RNN cells."""
+
+     def test_all_cells_compatible_interface(self):
+         """Test that all cells have compatible interfaces."""
+         cell_types = [ValinaRNNCell, GRUCell, MGUCell, LSTMCell, URLSTMCell]
+
+         for CellType in cell_types:
+             cell = CellType(num_in=self.num_in, num_out=self.num_out)
+
+             # Test init_state
+             cell.init_state(batch_size=self.batch_size)
+
+             # Test update
+             output = cell.update(self.x)
+             self.assertEqual(output.shape, (self.batch_size, self.num_out))
+
+             # Test reset_state
+             cell.reset_state(batch_size=16)
+
+             # Test with new batch size
+             x_small = jnp.ones((16, self.num_in))
+             output_small = cell.update(x_small)
+             self.assertEqual(output_small.shape, (16, self.num_out))
+
+     def test_sequence_to_sequence(self):
+         """Test sequence-to-sequence processing."""
+         cell_types = [ValinaRNNCell, GRUCell, MGUCell, LSTMCell, URLSTMCell]
+
+         for CellType in cell_types:
+             cell = CellType(num_in=self.num_in, num_out=self.num_out)
+             cell.init_state(batch_size=self.batch_size)
+
+             outputs = []
+             for t in range(self.sequence_length):
+                 output = cell.update(self.sequence[t])
+                 outputs.append(output)
+
+             outputs = jnp.stack(outputs)
+             self.assertEqual(
+                 outputs.shape,
+                 (self.sequence_length, self.batch_size, self.num_out)
+             )
+
+     def test_variable_length_sequences(self):
+         """Test handling of variable length sequences with masking."""
+         cell = LSTMCell(num_in=self.num_in, num_out=self.num_out)
+         cell.init_state(batch_size=self.batch_size)
+
+         # Create mask for variable lengths
+         lengths = jnp.array([10, 20, 30, 40] * (self.batch_size // 4))
+         mask = jnp.arange(self.sequence_length)[:, None] < lengths[None, :]
+
+         outputs = []
+         for t in range(self.sequence_length):
+             output = cell.update(self.sequence[t])
+             # Apply mask
+             output = output * mask[t:t+1].T
+             outputs.append(output)
+
+         outputs = jnp.stack(outputs)
+
+         # Check that masked positions are zero
+         for b in range(self.batch_size):
+             length = lengths[b]
+             if length < self.sequence_length:
+                 self.assertTrue(jnp.allclose(outputs[length:, b, :], 0.0))
+
+     def test_gradient_clipping(self):
+         """Test gradient clipping during training."""
+         cell = GRUCell(num_in=self.num_in, num_out=self.num_out)
+         cell.init_state(batch_size=self.batch_size)
+
+         def loss_fn(x):
+             output = jnp.zeros((self.batch_size, self.num_out))
+             for t in range(50):  # Long sequence
+                 output = cell.update(x * (t + 1))  # Amplify input
+             return jnp.mean(output ** 2)
+
+         grad_fn = jax.grad(loss_fn)
+         grad = grad_fn(self.x)
+
+         # Gradients should not explode
+         self.assertFalse(jnp.any(jnp.isnan(grad)))
+         self.assertFalse(jnp.any(jnp.isinf(grad)))
+         self.assertLess(jnp.max(jnp.abs(grad)), 1e6)
+
+
+ class TestRNNCellEdgeCases(TestRNNCellBase):
+     """Edge case tests for RNN cells."""
+
+     def test_single_sample(self):
+         """Test with batch size of 1."""
+         cell = ValinaRNNCell(num_in=self.num_in, num_out=self.num_out)
+         cell.init_state(batch_size=1)
+
+         x = jnp.ones((1, self.num_in))
+         output = cell.update(x)
+         self.assertEqual(output.shape, (1, self.num_out))
+
+     def test_zero_input(self):
+         """Test with zero inputs."""
+         cell_types = [ValinaRNNCell, GRUCell, MGUCell, LSTMCell, URLSTMCell]
+
+         for CellType in cell_types:
+             cell = CellType(num_in=self.num_in, num_out=self.num_out)
+             cell.init_state(batch_size=self.batch_size)
+
+             zero_input = jnp.zeros((self.batch_size, self.num_in))
+             output = cell.update(zero_input)
+
+             self.assertEqual(output.shape, (self.batch_size, self.num_out))
+             self.assertFalse(jnp.any(jnp.isnan(output)))
+
+     def test_large_input_values(self):
+         """Test with large input values."""
+         cell = LSTMCell(num_in=self.num_in, num_out=self.num_out)
+         cell.init_state(batch_size=self.batch_size)
+
+         large_input = jnp.ones((self.batch_size, self.num_in)) * 100
+         output = cell.update(large_input)
+
+         # Should handle large inputs gracefully (sigmoid saturation)
+         self.assertFalse(jnp.any(jnp.isnan(output)))
+         self.assertFalse(jnp.any(jnp.isinf(output)))
+
+     def test_very_long_sequence(self):
+         """Test with very long sequences."""
+         cell = GRUCell(num_in=self.num_in, num_out=self.num_out)
+         cell.init_state(batch_size=4)  # Smaller batch for memory
+
+         long_sequence = jnp.ones((1000, 4, self.num_in))
+
+         final_output = None
+         for t in range(1000):
+             final_output = cell.update(long_sequence[t])
+
+         # Should not have numerical issues even after long sequence
+         self.assertFalse(jnp.any(jnp.isnan(final_output)))
+         self.assertFalse(jnp.any(jnp.isinf(final_output)))
+
+     def test_dimension_mismatch_error(self):
+         """Test that dimension mismatches raise appropriate errors."""
+         cell = ValinaRNNCell(num_in=self.num_in, num_out=self.num_out)
+         cell.init_state(batch_size=self.batch_size)
+
+         # Wrong input dimension should raise error
+         wrong_input = jnp.ones((self.batch_size, self.num_in + 5))
+
+         with self.assertRaises(Exception):
+             _ = cell.update(wrong_input)
+
+
+ class TestRNNCellProperties(TestRNNCellBase):
+     """Test cell properties and attributes."""
+
+     def test_cell_attributes(self):
+         """Test that cells have correct attributes."""
+         cell = LSTMCell(num_in=self.num_in, num_out=self.num_out)
+
+         self.assertEqual(cell.num_in, self.num_in)
+         self.assertEqual(cell.num_out, self.num_out)
+         self.assertEqual(cell.in_size, (self.num_in,))
+         self.assertEqual(cell.out_size, (self.num_out,))
+
+     def test_inheritance_structure(self):
+         """Test that all cells inherit from RNNCell."""
+         cell_types = [ValinaRNNCell, GRUCell, MGUCell, LSTMCell, URLSTMCell]
+
+         for CellType in cell_types:
+             cell = CellType(num_in=self.num_in, num_out=self.num_out)
+             self.assertIsInstance(cell, RNNCell)
+
+     def test_docstring_presence(self):
+         """Test that all cells have proper docstrings."""
+         cell_types = [ValinaRNNCell, GRUCell, MGUCell, LSTMCell, URLSTMCell]
+
+         for CellType in cell_types:
+             self.assertIsNotNone(CellType.__doc__)
+             self.assertIn("Examples", CellType.__doc__)
+             self.assertIn("Parameters", CellType.__doc__)
+             self.assertIn(">>>", CellType.__doc__)
+
+
+ if __name__ == '__main__':
+     unittest.main()
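
All of the cells exercised above share one interface: construct with num_in/num_out, allocate state with init_state(batch_size=...), then call update(x) once per time step. A standalone sketch mirroring the tests (plain Python loop over time, assuming brainstate 0.2.0 and JAX are installed):

    import jax.numpy as jnp
    import brainstate.nn as nn

    cell = nn.GRUCell(num_in=10, num_out=20)
    cell.init_state(batch_size=8)

    xs = jnp.ones((100, 8, 10))                   # (time, batch, features)
    ys = jnp.stack([cell.update(x) for x in xs])  # -> (100, 8, 20)

Since the new test module guards on __main__ with unittest.main(), it can presumably also be run directly, e.g. python -m unittest brainstate.nn._rnns_test -v.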