brainstate 0.1.10__py2.py3-none-any.whl → 0.2.1__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (163) hide show
  1. brainstate/__init__.py +169 -58
  2. brainstate/_compatible_import.py +340 -148
  3. brainstate/_compatible_import_test.py +681 -0
  4. brainstate/_deprecation.py +210 -0
  5. brainstate/_deprecation_test.py +2319 -0
  6. brainstate/{util/error.py → _error.py} +45 -55
  7. brainstate/_state.py +1652 -1605
  8. brainstate/_state_test.py +52 -52
  9. brainstate/_utils.py +47 -47
  10. brainstate/environ.py +1495 -563
  11. brainstate/environ_test.py +1223 -62
  12. brainstate/graph/__init__.py +22 -29
  13. brainstate/graph/_node.py +240 -0
  14. brainstate/graph/_node_test.py +589 -0
  15. brainstate/graph/{_graph_operation.py → _operation.py} +1624 -1738
  16. brainstate/graph/_operation_test.py +1147 -0
  17. brainstate/mixin.py +1433 -365
  18. brainstate/mixin_test.py +1017 -77
  19. brainstate/nn/__init__.py +137 -135
  20. brainstate/nn/_activations.py +1100 -808
  21. brainstate/nn/_activations_test.py +354 -331
  22. brainstate/nn/_collective_ops.py +633 -514
  23. brainstate/nn/_collective_ops_test.py +774 -43
  24. brainstate/nn/_common.py +226 -178
  25. brainstate/nn/_common_test.py +154 -0
  26. brainstate/nn/_conv.py +2010 -501
  27. brainstate/nn/_conv_test.py +849 -238
  28. brainstate/nn/_delay.py +575 -588
  29. brainstate/nn/_delay_test.py +243 -238
  30. brainstate/nn/_dropout.py +618 -426
  31. brainstate/nn/_dropout_test.py +477 -100
  32. brainstate/nn/_dynamics.py +1267 -1343
  33. brainstate/nn/_dynamics_test.py +67 -78
  34. brainstate/nn/_elementwise.py +1298 -1119
  35. brainstate/nn/_elementwise_test.py +830 -169
  36. brainstate/nn/_embedding.py +408 -58
  37. brainstate/nn/_embedding_test.py +156 -0
  38. brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +233 -239
  39. brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +115 -114
  40. brainstate/nn/{_linear_mv.py → _event_linear.py} +83 -83
  41. brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +121 -120
  42. brainstate/nn/_exp_euler.py +254 -92
  43. brainstate/nn/_exp_euler_test.py +377 -35
  44. brainstate/nn/_linear.py +744 -424
  45. brainstate/nn/_linear_test.py +475 -107
  46. brainstate/nn/_metrics.py +1070 -0
  47. brainstate/nn/_metrics_test.py +611 -0
  48. brainstate/nn/_module.py +384 -377
  49. brainstate/nn/_module_test.py +40 -40
  50. brainstate/nn/_normalizations.py +1334 -975
  51. brainstate/nn/_normalizations_test.py +699 -73
  52. brainstate/nn/_paddings.py +1020 -0
  53. brainstate/nn/_paddings_test.py +723 -0
  54. brainstate/nn/_poolings.py +2239 -1177
  55. brainstate/nn/_poolings_test.py +953 -217
  56. brainstate/nn/{_rate_rnns.py → _rnns.py} +946 -554
  57. brainstate/nn/_rnns_test.py +593 -0
  58. brainstate/nn/_utils.py +216 -89
  59. brainstate/nn/_utils_test.py +402 -0
  60. brainstate/{init/_random_inits.py → nn/init.py} +809 -553
  61. brainstate/{init/_random_inits_test.py → nn/init_test.py} +180 -149
  62. brainstate/random/__init__.py +270 -24
  63. brainstate/random/_rand_funs.py +3938 -3616
  64. brainstate/random/_rand_funs_test.py +640 -567
  65. brainstate/random/_rand_seed.py +675 -210
  66. brainstate/random/_rand_seed_test.py +48 -48
  67. brainstate/random/_rand_state.py +1617 -1409
  68. brainstate/random/_rand_state_test.py +551 -0
  69. brainstate/transform/__init__.py +59 -0
  70. brainstate/transform/_ad_checkpoint.py +176 -0
  71. brainstate/{compile → transform}/_ad_checkpoint_test.py +49 -49
  72. brainstate/{augment → transform}/_autograd.py +1025 -778
  73. brainstate/{augment → transform}/_autograd_test.py +1289 -1289
  74. brainstate/transform/_conditions.py +316 -0
  75. brainstate/{compile → transform}/_conditions_test.py +220 -220
  76. brainstate/{compile → transform}/_error_if.py +94 -92
  77. brainstate/{compile → transform}/_error_if_test.py +52 -52
  78. brainstate/transform/_eval_shape.py +145 -0
  79. brainstate/{augment → transform}/_eval_shape_test.py +38 -38
  80. brainstate/{compile → transform}/_jit.py +399 -346
  81. brainstate/{compile → transform}/_jit_test.py +143 -143
  82. brainstate/{compile → transform}/_loop_collect_return.py +675 -536
  83. brainstate/{compile → transform}/_loop_collect_return_test.py +58 -58
  84. brainstate/{compile → transform}/_loop_no_collection.py +283 -184
  85. brainstate/{compile → transform}/_loop_no_collection_test.py +50 -50
  86. brainstate/transform/_make_jaxpr.py +2016 -0
  87. brainstate/transform/_make_jaxpr_test.py +1510 -0
  88. brainstate/transform/_mapping.py +529 -0
  89. brainstate/transform/_mapping_test.py +194 -0
  90. brainstate/{compile → transform}/_progress_bar.py +255 -202
  91. brainstate/{augment → transform}/_random.py +171 -151
  92. brainstate/{compile → transform}/_unvmap.py +256 -159
  93. brainstate/transform/_util.py +286 -0
  94. brainstate/typing.py +837 -304
  95. brainstate/typing_test.py +780 -0
  96. brainstate/util/__init__.py +27 -50
  97. brainstate/util/_others.py +1025 -0
  98. brainstate/util/_others_test.py +962 -0
  99. brainstate/util/_pretty_pytree.py +1301 -0
  100. brainstate/util/_pretty_pytree_test.py +675 -0
  101. brainstate/util/{pretty_repr.py → _pretty_repr.py} +462 -328
  102. brainstate/util/_pretty_repr_test.py +696 -0
  103. brainstate/util/filter.py +945 -469
  104. brainstate/util/filter_test.py +912 -0
  105. brainstate/util/struct.py +910 -523
  106. brainstate/util/struct_test.py +602 -0
  107. {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/METADATA +108 -91
  108. brainstate-0.2.1.dist-info/RECORD +111 -0
  109. {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/licenses/LICENSE +202 -202
  110. brainstate/augment/__init__.py +0 -30
  111. brainstate/augment/_eval_shape.py +0 -99
  112. brainstate/augment/_mapping.py +0 -1060
  113. brainstate/augment/_mapping_test.py +0 -597
  114. brainstate/compile/__init__.py +0 -38
  115. brainstate/compile/_ad_checkpoint.py +0 -204
  116. brainstate/compile/_conditions.py +0 -256
  117. brainstate/compile/_make_jaxpr.py +0 -888
  118. brainstate/compile/_make_jaxpr_test.py +0 -156
  119. brainstate/compile/_util.py +0 -147
  120. brainstate/functional/__init__.py +0 -27
  121. brainstate/graph/_graph_node.py +0 -244
  122. brainstate/graph/_graph_node_test.py +0 -73
  123. brainstate/graph/_graph_operation_test.py +0 -563
  124. brainstate/init/__init__.py +0 -26
  125. brainstate/init/_base.py +0 -52
  126. brainstate/init/_generic.py +0 -244
  127. brainstate/init/_regular_inits.py +0 -105
  128. brainstate/init/_regular_inits_test.py +0 -50
  129. brainstate/nn/_inputs.py +0 -608
  130. brainstate/nn/_ltp.py +0 -28
  131. brainstate/nn/_neuron.py +0 -705
  132. brainstate/nn/_neuron_test.py +0 -161
  133. brainstate/nn/_others.py +0 -46
  134. brainstate/nn/_projection.py +0 -486
  135. brainstate/nn/_rate_rnns_test.py +0 -63
  136. brainstate/nn/_readout.py +0 -209
  137. brainstate/nn/_readout_test.py +0 -53
  138. brainstate/nn/_stp.py +0 -236
  139. brainstate/nn/_synapse.py +0 -505
  140. brainstate/nn/_synapse_test.py +0 -131
  141. brainstate/nn/_synaptic_projection.py +0 -423
  142. brainstate/nn/_synouts.py +0 -162
  143. brainstate/nn/_synouts_test.py +0 -57
  144. brainstate/nn/metrics.py +0 -388
  145. brainstate/optim/__init__.py +0 -38
  146. brainstate/optim/_base.py +0 -64
  147. brainstate/optim/_lr_scheduler.py +0 -448
  148. brainstate/optim/_lr_scheduler_test.py +0 -50
  149. brainstate/optim/_optax_optimizer.py +0 -152
  150. brainstate/optim/_optax_optimizer_test.py +0 -53
  151. brainstate/optim/_sgd_optimizer.py +0 -1104
  152. brainstate/random/_random_for_unit.py +0 -52
  153. brainstate/surrogate.py +0 -1957
  154. brainstate/transform.py +0 -23
  155. brainstate/util/caller.py +0 -98
  156. brainstate/util/others.py +0 -540
  157. brainstate/util/pretty_pytree.py +0 -945
  158. brainstate/util/pretty_pytree_test.py +0 -159
  159. brainstate/util/pretty_table.py +0 -2954
  160. brainstate/util/scaling.py +0 -258
  161. brainstate-0.1.10.dist-info/RECORD +0 -130
  162. {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/WHEEL +0 -0
  163. {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,593 @@
1
+ # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """Comprehensive tests for RNN cell implementations."""
17
+
18
+ import unittest
19
+ from typing import Type
20
+
21
+ import jax
22
+ import jax.numpy as jnp
23
+ import numpy as np
24
+
25
+ import brainstate
26
+ import brainstate.nn as nn
27
+ from brainstate.nn import RNNCell, ValinaRNNCell, GRUCell, MGUCell, LSTMCell, URLSTMCell
28
+ from brainstate.nn import init as init
29
+ from brainstate.nn import _activations as functional
30
+
31
+
32
+ class TestRNNCellBase(unittest.TestCase):
33
+ """Base test class for all RNN cell implementations."""
34
+
35
+ def setUp(self):
36
+ """Set up test fixtures."""
37
+ self.num_in = 10
38
+ self.num_out = 20
39
+ self.batch_size = 32
40
+ self.sequence_length = 100
41
+ self.seed = 42
42
+
43
+ # Initialize random inputs
44
+ key = jax.random.PRNGKey(self.seed)
45
+ self.x = jax.random.normal(key, (self.batch_size, self.num_in))
46
+ self.sequence = jax.random.normal(
47
+ key, (self.sequence_length, self.batch_size, self.num_in)
48
+ )
49
+
50
+
51
+ class TestVanillaRNNCell(TestRNNCellBase):
52
+ """Comprehensive tests for VanillaRNNCell."""
53
+
54
+ def test_basic_forward(self):
55
+ """Test basic forward pass."""
56
+ cell = ValinaRNNCell(num_in=self.num_in, num_out=self.num_out)
57
+ cell.init_state(batch_size=self.batch_size)
58
+
59
+ output = cell.update(self.x)
60
+
61
+ self.assertEqual(output.shape, (self.batch_size, self.num_out))
62
+ self.assertFalse(jnp.any(jnp.isnan(output)))
63
+ self.assertFalse(jnp.any(jnp.isinf(output)))
64
+
65
+ def test_sequence_processing(self):
66
+ """Test processing a sequence of inputs."""
67
+ cell = ValinaRNNCell(num_in=self.num_in, num_out=self.num_out)
68
+ cell.init_state(batch_size=self.batch_size)
69
+
70
+ outputs = []
71
+ for t in range(self.sequence_length):
72
+ output = cell.update(self.sequence[t])
73
+ outputs.append(output)
74
+
75
+ outputs = jnp.stack(outputs)
76
+ self.assertEqual(outputs.shape, (self.sequence_length, self.batch_size, self.num_out))
77
+ self.assertFalse(jnp.any(jnp.isnan(outputs)))
78
+
79
+ def test_state_reset(self):
80
+ """Test state reset functionality."""
81
+ cell = ValinaRNNCell(num_in=self.num_in, num_out=self.num_out)
82
+ cell.init_state(batch_size=self.batch_size)
83
+
84
+ # Process some input
85
+ _ = cell.update(self.x)
86
+ state_before = cell.h.value.copy()
87
+
88
+ # Reset state
89
+ cell.reset_state(batch_size=self.batch_size)
90
+ state_after = cell.h.value.copy()
91
+
92
+ # States should be different (unless randomly the same, which is unlikely)
93
+ self.assertFalse(jnp.allclose(state_before, state_after, atol=1e-6))
94
+
95
+ def test_different_batch_sizes(self):
96
+ """Test with different batch sizes."""
97
+ cell = ValinaRNNCell(num_in=self.num_in, num_out=self.num_out)
98
+
99
+ for batch_size in [1, 16, 64]:
100
+ cell.init_state(batch_size=batch_size)
101
+ x = jnp.ones((batch_size, self.num_in))
102
+ output = cell.update(x)
103
+ self.assertEqual(output.shape, (batch_size, self.num_out))
104
+
105
+ def test_activation_functions(self):
106
+ """Test different activation functions."""
107
+ activations = ['relu', 'tanh', 'sigmoid', 'gelu']
108
+
109
+ for activation in activations:
110
+ cell = ValinaRNNCell(
111
+ num_in=self.num_in,
112
+ num_out=self.num_out,
113
+ activation=activation
114
+ )
115
+ cell.init_state(batch_size=self.batch_size)
116
+ output = cell.update(self.x)
117
+ self.assertEqual(output.shape, (self.batch_size, self.num_out))
118
+ self.assertFalse(jnp.any(jnp.isnan(output)))
119
+
120
+ def test_custom_initializers(self):
121
+ """Test custom weight and state initializers."""
122
+ cell = ValinaRNNCell(
123
+ num_in=self.num_in,
124
+ num_out=self.num_out,
125
+ w_init=init.Orthogonal(),
126
+ b_init=init.Constant(0.1),
127
+ state_init=init.Normal(0.01)
128
+ )
129
+ cell.init_state(batch_size=self.batch_size)
130
+ output = cell.update(self.x)
131
+ self.assertEqual(output.shape, (self.batch_size, self.num_out))
132
+
133
+ def test_gradient_flow(self):
134
+ """Test gradient flow through the cell."""
135
+ cell = ValinaRNNCell(num_in=self.num_in, num_out=self.num_out)
136
+ cell.init_state(batch_size=self.batch_size)
137
+
138
+ def loss_fn(x):
139
+ output = cell.update(x)
140
+ return jnp.mean(output ** 2)
141
+
142
+ grad_fn = jax.grad(loss_fn)
143
+ grad = grad_fn(self.x)
144
+
145
+ self.assertEqual(grad.shape, self.x.shape)
146
+ self.assertFalse(jnp.any(jnp.isnan(grad)))
147
+ self.assertTrue(jnp.any(grad != 0)) # Gradients should be non-zero
148
+
149
+
150
+ class TestGRUCell(TestRNNCellBase):
151
+ """Comprehensive tests for GRUCell."""
152
+
153
+ def test_basic_forward(self):
154
+ """Test basic forward pass."""
155
+ cell = GRUCell(num_in=self.num_in, num_out=self.num_out)
156
+ cell.init_state(batch_size=self.batch_size)
157
+
158
+ output = cell.update(self.x)
159
+
160
+ self.assertEqual(output.shape, (self.batch_size, self.num_out))
161
+ self.assertFalse(jnp.any(jnp.isnan(output)))
162
+
163
+ def test_gating_mechanism(self):
164
+ """Test that gating values are in valid range."""
165
+ cell = GRUCell(num_in=self.num_in, num_out=self.num_out)
166
+ cell.init_state(batch_size=self.batch_size)
167
+
168
+ # Access internal computation
169
+ old_h = cell.h.value
170
+ xh = jnp.concatenate([self.x, old_h], axis=-1)
171
+ gates = functional.sigmoid(cell.Wrz(xh))
172
+
173
+ # Gates should be between 0 and 1
174
+ self.assertTrue(jnp.all(gates >= 0))
175
+ self.assertTrue(jnp.all(gates <= 1))
176
+
177
+ def test_state_persistence(self):
178
+ """Test that state persists across updates."""
179
+ cell = GRUCell(num_in=self.num_in, num_out=self.num_out)
180
+ cell.init_state(batch_size=self.batch_size)
181
+
182
+ # Process sequence and track states
183
+ states = []
184
+ for t in range(10):
185
+ _ = cell.update(self.sequence[t])
186
+ states.append(cell.h.value.copy())
187
+
188
+ # States should evolve over time
189
+ for i in range(1, len(states)):
190
+ self.assertFalse(jnp.allclose(states[i], states[i-1], atol=1e-8))
191
+
192
+ def test_reset_vs_update_gates(self):
193
+ """Test that reset and update gates behave differently."""
194
+ cell = GRUCell(num_in=self.num_in, num_out=self.num_out)
195
+ cell.init_state(batch_size=self.batch_size)
196
+
197
+ # Get gates for the same input
198
+ old_h = cell.h.value
199
+ xh = jnp.concatenate([self.x, old_h], axis=-1)
200
+ r, z = jnp.split(functional.sigmoid(cell.Wrz(xh)), indices_or_sections=2, axis=-1)
201
+
202
+ # Reset and update gates should be different
203
+ self.assertFalse(jnp.allclose(r, z, atol=1e-6))
204
+
205
+ def test_different_initializers(self):
206
+ """Test with different weight initializers."""
207
+ initializers = [
208
+ init.XavierNormal(),
209
+ init.XavierUniform(),
210
+ init.Orthogonal(),
211
+ init.KaimingNormal(),
212
+ ]
213
+
214
+ for w_init in initializers:
215
+ cell = GRUCell(
216
+ num_in=self.num_in,
217
+ num_out=self.num_out,
218
+ w_init=w_init
219
+ )
220
+ cell.init_state(batch_size=self.batch_size)
221
+ output = cell.update(self.x)
222
+ self.assertEqual(output.shape, (self.batch_size, self.num_out))
223
+
224
+
225
+ class TestMGUCell(TestRNNCellBase):
226
+ """Comprehensive tests for MGUCell."""
227
+
228
+ def test_basic_forward(self):
229
+ """Test basic forward pass."""
230
+ cell = MGUCell(num_in=self.num_in, num_out=self.num_out)
231
+ cell.init_state(batch_size=self.batch_size)
232
+
233
+ output = cell.update(self.x)
234
+
235
+ self.assertEqual(output.shape, (self.batch_size, self.num_out))
236
+ self.assertFalse(jnp.any(jnp.isnan(output)))
237
+
238
+ def test_single_gate_mechanism(self):
239
+ """Test that MGU uses single forget gate."""
240
+ cell = MGUCell(num_in=self.num_in, num_out=self.num_out)
241
+ cell.init_state(batch_size=self.batch_size)
242
+
243
+ # Check that only one gate is computed
244
+ xh = jnp.concatenate([self.x, cell.h.value], axis=-1)
245
+ f = functional.sigmoid(cell.Wf(xh))
246
+
247
+ # Forget gate should be between 0 and 1
248
+ self.assertTrue(jnp.all(f >= 0))
249
+ self.assertTrue(jnp.all(f <= 1))
250
+ self.assertEqual(f.shape, (self.batch_size, self.num_out))
251
+
252
+ def test_parameter_efficiency(self):
253
+ """Test that MGU has fewer parameters than GRU."""
254
+ mgu_cell = MGUCell(num_in=self.num_in, num_out=self.num_out)
255
+ gru_cell = GRUCell(num_in=self.num_in, num_out=self.num_out)
256
+
257
+ # Count parameters - MGU should have fewer
258
+ # MGU has 2 weight matrices (Wf, Wh)
259
+ # GRU has 2 weight matrices but one is double size (Wrz, Wh)
260
+ mgu_param_count = 2 * ((self.num_in + self.num_out) * self.num_out + self.num_out)
261
+ gru_param_count = ((self.num_in + self.num_out) * (self.num_out * 2) + self.num_out * 2) + \
262
+ ((self.num_in + self.num_out) * self.num_out + self.num_out)
263
+
264
+ self.assertLess(mgu_param_count, gru_param_count)
265
+
266
+
267
+ class TestLSTMCell(TestRNNCellBase):
268
+ """Comprehensive tests for LSTMCell."""
269
+
270
+ def test_basic_forward(self):
271
+ """Test basic forward pass."""
272
+ cell = LSTMCell(num_in=self.num_in, num_out=self.num_out)
273
+ cell.init_state(batch_size=self.batch_size)
274
+
275
+ output = cell.update(self.x)
276
+
277
+ self.assertEqual(output.shape, (self.batch_size, self.num_out))
278
+ self.assertFalse(jnp.any(jnp.isnan(output)))
279
+
280
+ def test_dual_state_mechanism(self):
281
+ """Test that LSTM maintains both hidden and cell states."""
282
+ cell = LSTMCell(num_in=self.num_in, num_out=self.num_out)
283
+ cell.init_state(batch_size=self.batch_size)
284
+
285
+ # Check initial states
286
+ self.assertIsNotNone(cell.h)
287
+ self.assertIsNotNone(cell.c)
288
+ self.assertEqual(cell.h.value.shape, (self.batch_size, self.num_out))
289
+ self.assertEqual(cell.c.value.shape, (self.batch_size, self.num_out))
290
+
291
+ # Update and check states change
292
+ h_before = cell.h.value.copy()
293
+ c_before = cell.c.value.copy()
294
+
295
+ _ = cell.update(self.x)
296
+
297
+ self.assertFalse(jnp.allclose(cell.h.value, h_before, atol=1e-8))
298
+ self.assertFalse(jnp.allclose(cell.c.value, c_before, atol=1e-8))
299
+
300
+ def test_forget_gate_bias(self):
301
+ """Test that forget gate has positive bias initialization."""
302
+ cell = LSTMCell(num_in=self.num_in, num_out=self.num_out)
303
+ cell.init_state(batch_size=self.batch_size)
304
+
305
+ # Process with zero input to see bias effect
306
+ zero_input = jnp.zeros((self.batch_size, self.num_in))
307
+ xh = jnp.concatenate([zero_input, cell.h.value], axis=-1)
308
+ gates = cell.W(xh)
309
+ _, _, f, _ = jnp.split(gates, indices_or_sections=4, axis=-1)
310
+ f_gate = functional.sigmoid(f + 1.) # Note the +1 bias
311
+
312
+ # Forget gate should be biased towards remembering (> 0.5)
313
+ self.assertTrue(jnp.mean(f_gate) > 0.5)
314
+
315
+ def test_gate_values_range(self):
316
+ """Test that all gates produce values in [0, 1]."""
317
+ cell = LSTMCell(num_in=self.num_in, num_out=self.num_out)
318
+ cell.init_state(batch_size=self.batch_size)
319
+
320
+ xh = jnp.concatenate([self.x, cell.h.value], axis=-1)
321
+ i, g, f, o = jnp.split(cell.W(xh), indices_or_sections=4, axis=-1)
322
+
323
+ i_gate = functional.sigmoid(i)
324
+ f_gate = functional.sigmoid(f + 1.)
325
+ o_gate = functional.sigmoid(o)
326
+
327
+ for gate in [i_gate, f_gate, o_gate]:
328
+ self.assertTrue(jnp.all(gate >= 0))
329
+ self.assertTrue(jnp.all(gate <= 1))
330
+
331
+ def test_cell_state_gradient_flow(self):
332
+ """Test gradient flow through cell state."""
333
+ cell = LSTMCell(num_in=self.num_in, num_out=self.num_out)
334
+ cell.init_state(batch_size=self.batch_size)
335
+
336
+ def loss_fn(x):
337
+ for t in range(10):
338
+ _ = cell.update(x)
339
+ return jnp.mean(cell.c.value ** 2)
340
+
341
+ grad_fn = jax.grad(loss_fn)
342
+ grad = grad_fn(self.x)
343
+
344
+ self.assertFalse(jnp.any(jnp.isnan(grad)))
345
+ self.assertTrue(jnp.any(grad != 0))
346
+
347
+
348
+ class TestURLSTMCell(TestRNNCellBase):
349
+ """Comprehensive tests for URLSTMCell."""
350
+
351
+ def test_basic_forward(self):
352
+ """Test basic forward pass."""
353
+ cell = URLSTMCell(num_in=self.num_in, num_out=self.num_out)
354
+ cell.init_state(batch_size=self.batch_size)
355
+
356
+ output = cell.update(self.x)
357
+
358
+ self.assertEqual(output.shape, (self.batch_size, self.num_out))
359
+ self.assertFalse(jnp.any(jnp.isnan(output)))
360
+
361
+ def test_untied_bias_mechanism(self):
362
+ """Test the untied bias initialization."""
363
+ cell = URLSTMCell(num_in=self.num_in, num_out=self.num_out)
364
+ cell.init_state(batch_size=self.batch_size)
365
+
366
+ # Check bias values are initialized
367
+ self.assertIsNotNone(cell.bias.value)
368
+ self.assertEqual(cell.bias.value.shape, (self.num_out,))
369
+
370
+ # Biases should be diverse (not all the same)
371
+ self.assertGreater(jnp.std(cell.bias.value), 0.1)
372
+
373
+ def test_unified_gate_computation(self):
374
+ """Test the unified gate mechanism."""
375
+ cell = URLSTMCell(num_in=self.num_in, num_out=self.num_out)
376
+ cell.init_state(batch_size=self.batch_size)
377
+
378
+ h, c = cell.h.value, cell.c.value
379
+ xh = jnp.concatenate([self.x, h], axis=-1)
380
+ gates = cell.W(xh)
381
+ f, r, u, o = jnp.split(gates, indices_or_sections=4, axis=-1)
382
+
383
+ f_gate = functional.sigmoid(f + cell.bias.value)
384
+ r_gate = functional.sigmoid(r - cell.bias.value)
385
+
386
+ # Compute unified gate
387
+ g = 2 * r_gate * f_gate + (1 - 2 * r_gate) * f_gate ** 2
388
+
389
+ # Unified gate should be in [0, 1]
390
+ self.assertTrue(jnp.all(g >= 0))
391
+ self.assertTrue(jnp.all(g <= 1))
392
+
393
+ def test_comparison_with_lstm(self):
394
+ """Test that URLSTM behaves differently from standard LSTM."""
395
+ urlstm = URLSTMCell(num_in=self.num_in, num_out=self.num_out, state_init=init.Constant(0.5))
396
+ lstm = LSTMCell(num_in=self.num_in, num_out=self.num_out, state_init=init.Constant(0.5))
397
+
398
+ urlstm.init_state(batch_size=self.batch_size)
399
+ lstm.init_state(batch_size=self.batch_size)
400
+
401
+ # Same input should produce different outputs
402
+ urlstm_out = urlstm.update(self.x)
403
+ lstm_out = lstm.update(self.x)
404
+
405
+ self.assertFalse(jnp.allclose(urlstm_out, lstm_out, atol=1e-4))
406
+
407
+
408
+ class TestRNNCellIntegration(TestRNNCellBase):
409
+ """Integration tests for all RNN cells."""
410
+
411
+ def test_all_cells_compatible_interface(self):
412
+ """Test that all cells have compatible interfaces."""
413
+ cell_types = [ValinaRNNCell, GRUCell, MGUCell, LSTMCell, URLSTMCell]
414
+
415
+ for CellType in cell_types:
416
+ cell = CellType(num_in=self.num_in, num_out=self.num_out)
417
+
418
+ # Test init_state
419
+ cell.init_state(batch_size=self.batch_size)
420
+
421
+ # Test update
422
+ output = cell.update(self.x)
423
+ self.assertEqual(output.shape, (self.batch_size, self.num_out))
424
+
425
+ # Test reset_state
426
+ cell.reset_state(batch_size=16)
427
+
428
+ # Test with new batch size
429
+ x_small = jnp.ones((16, self.num_in))
430
+ output_small = cell.update(x_small)
431
+ self.assertEqual(output_small.shape, (16, self.num_out))
432
+
433
+ def test_sequence_to_sequence(self):
434
+ """Test sequence-to-sequence processing."""
435
+ cell_types = [ValinaRNNCell, GRUCell, MGUCell, LSTMCell, URLSTMCell]
436
+
437
+ for CellType in cell_types:
438
+ cell = CellType(num_in=self.num_in, num_out=self.num_out)
439
+ cell.init_state(batch_size=self.batch_size)
440
+
441
+ outputs = []
442
+ for t in range(self.sequence_length):
443
+ output = cell.update(self.sequence[t])
444
+ outputs.append(output)
445
+
446
+ outputs = jnp.stack(outputs)
447
+ self.assertEqual(
448
+ outputs.shape,
449
+ (self.sequence_length, self.batch_size, self.num_out)
450
+ )
451
+
452
+ def test_variable_length_sequences(self):
453
+ """Test handling of variable length sequences with masking."""
454
+ cell = LSTMCell(num_in=self.num_in, num_out=self.num_out)
455
+ cell.init_state(batch_size=self.batch_size)
456
+
457
+ # Create mask for variable lengths
458
+ lengths = jnp.array([10, 20, 30, 40] * (self.batch_size // 4))
459
+ mask = jnp.arange(self.sequence_length)[:, None] < lengths[None, :]
460
+
461
+ outputs = []
462
+ for t in range(self.sequence_length):
463
+ output = cell.update(self.sequence[t])
464
+ # Apply mask
465
+ output = output * mask[t:t+1].T
466
+ outputs.append(output)
467
+
468
+ outputs = jnp.stack(outputs)
469
+
470
+ # Check that masked positions are zero
471
+ for b in range(self.batch_size):
472
+ length = lengths[b]
473
+ if length < self.sequence_length:
474
+ self.assertTrue(jnp.allclose(outputs[length:, b, :], 0.0))
475
+
476
+ def test_gradient_clipping(self):
477
+ """Test gradient clipping during training."""
478
+ cell = GRUCell(num_in=self.num_in, num_out=self.num_out)
479
+ cell.init_state(batch_size=self.batch_size)
480
+
481
+ def loss_fn(x):
482
+ output = jnp.zeros((self.batch_size, self.num_out))
483
+ for t in range(50): # Long sequence
484
+ output = cell.update(x * (t + 1)) # Amplify input
485
+ return jnp.mean(output ** 2)
486
+
487
+ grad_fn = jax.grad(loss_fn)
488
+ grad = grad_fn(self.x)
489
+
490
+ # Gradients should not explode
491
+ self.assertFalse(jnp.any(jnp.isnan(grad)))
492
+ self.assertFalse(jnp.any(jnp.isinf(grad)))
493
+ self.assertLess(jnp.max(jnp.abs(grad)), 1e6)
494
+
495
+
496
+ class TestRNNCellEdgeCases(TestRNNCellBase):
497
+ """Edge case tests for RNN cells."""
498
+
499
+ def test_single_sample(self):
500
+ """Test with batch size of 1."""
501
+ cell = ValinaRNNCell(num_in=self.num_in, num_out=self.num_out)
502
+ cell.init_state(batch_size=1)
503
+
504
+ x = jnp.ones((1, self.num_in))
505
+ output = cell.update(x)
506
+ self.assertEqual(output.shape, (1, self.num_out))
507
+
508
+ def test_zero_input(self):
509
+ """Test with zero inputs."""
510
+ cell_types = [ValinaRNNCell, GRUCell, MGUCell, LSTMCell, URLSTMCell]
511
+
512
+ for CellType in cell_types:
513
+ cell = CellType(num_in=self.num_in, num_out=self.num_out)
514
+ cell.init_state(batch_size=self.batch_size)
515
+
516
+ zero_input = jnp.zeros((self.batch_size, self.num_in))
517
+ output = cell.update(zero_input)
518
+
519
+ self.assertEqual(output.shape, (self.batch_size, self.num_out))
520
+ self.assertFalse(jnp.any(jnp.isnan(output)))
521
+
522
+ def test_large_input_values(self):
523
+ """Test with large input values."""
524
+ cell = LSTMCell(num_in=self.num_in, num_out=self.num_out)
525
+ cell.init_state(batch_size=self.batch_size)
526
+
527
+ large_input = jnp.ones((self.batch_size, self.num_in)) * 100
528
+ output = cell.update(large_input)
529
+
530
+ # Should handle large inputs gracefully (sigmoid saturation)
531
+ self.assertFalse(jnp.any(jnp.isnan(output)))
532
+ self.assertFalse(jnp.any(jnp.isinf(output)))
533
+
534
+ def test_very_long_sequence(self):
535
+ """Test with very long sequences."""
536
+ cell = GRUCell(num_in=self.num_in, num_out=self.num_out)
537
+ cell.init_state(batch_size=4) # Smaller batch for memory
538
+
539
+ long_sequence = jnp.ones((1000, 4, self.num_in))
540
+
541
+ final_output = None
542
+ for t in range(1000):
543
+ final_output = cell.update(long_sequence[t])
544
+
545
+ # Should not have numerical issues even after long sequence
546
+ self.assertFalse(jnp.any(jnp.isnan(final_output)))
547
+ self.assertFalse(jnp.any(jnp.isinf(final_output)))
548
+
549
+ def test_dimension_mismatch_error(self):
550
+ """Test that dimension mismatches raise appropriate errors."""
551
+ cell = ValinaRNNCell(num_in=self.num_in, num_out=self.num_out)
552
+ cell.init_state(batch_size=self.batch_size)
553
+
554
+ # Wrong input dimension should raise error
555
+ wrong_input = jnp.ones((self.batch_size, self.num_in + 5))
556
+
557
+ with self.assertRaises(Exception):
558
+ _ = cell.update(wrong_input)
559
+
560
+
561
+ class TestRNNCellProperties(TestRNNCellBase):
562
+ """Test cell properties and attributes."""
563
+
564
+ def test_cell_attributes(self):
565
+ """Test that cells have correct attributes."""
566
+ cell = LSTMCell(num_in=self.num_in, num_out=self.num_out)
567
+
568
+ self.assertEqual(cell.num_in, self.num_in)
569
+ self.assertEqual(cell.num_out, self.num_out)
570
+ self.assertEqual(cell.in_size, (self.num_in,))
571
+ self.assertEqual(cell.out_size, (self.num_out,))
572
+
573
+ def test_inheritance_structure(self):
574
+ """Test that all cells inherit from RNNCell."""
575
+ cell_types = [ValinaRNNCell, GRUCell, MGUCell, LSTMCell, URLSTMCell]
576
+
577
+ for CellType in cell_types:
578
+ cell = CellType(num_in=self.num_in, num_out=self.num_out)
579
+ self.assertIsInstance(cell, RNNCell)
580
+
581
+ def test_docstring_presence(self):
582
+ """Test that all cells have proper docstrings."""
583
+ cell_types = [ValinaRNNCell, GRUCell, MGUCell, LSTMCell, URLSTMCell]
584
+
585
+ for CellType in cell_types:
586
+ self.assertIsNotNone(CellType.__doc__)
587
+ self.assertIn("Examples", CellType.__doc__)
588
+ self.assertIn("Parameters", CellType.__doc__)
589
+ self.assertIn(">>>", CellType.__doc__)
590
+
591
+
592
+ if __name__ == '__main__':
593
+ unittest.main()