brainstate-0.2.0-py2.py3-none-any.whl → brainstate-0.2.1-py2.py3-none-any.whl
This diff compares the contents of two publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions exactly as they appear in their respective public registries.
- brainstate/__init__.py +169 -169
- brainstate/_compatible_import.py +340 -340
- brainstate/_compatible_import_test.py +681 -681
- brainstate/_deprecation.py +210 -210
- brainstate/_deprecation_test.py +2319 -2319
- brainstate/_error.py +45 -45
- brainstate/_state.py +1652 -1652
- brainstate/_state_test.py +52 -52
- brainstate/_utils.py +47 -47
- brainstate/environ.py +1495 -1495
- brainstate/environ_test.py +1223 -1223
- brainstate/graph/__init__.py +22 -22
- brainstate/graph/_node.py +240 -240
- brainstate/graph/_node_test.py +589 -589
- brainstate/graph/_operation.py +1624 -1624
- brainstate/graph/_operation_test.py +1147 -1147
- brainstate/mixin.py +1433 -1433
- brainstate/mixin_test.py +1017 -1017
- brainstate/nn/__init__.py +137 -137
- brainstate/nn/_activations.py +1100 -1100
- brainstate/nn/_activations_test.py +354 -354
- brainstate/nn/_collective_ops.py +633 -633
- brainstate/nn/_collective_ops_test.py +774 -774
- brainstate/nn/_common.py +226 -226
- brainstate/nn/_common_test.py +154 -154
- brainstate/nn/_conv.py +2010 -2010
- brainstate/nn/_conv_test.py +849 -849
- brainstate/nn/_delay.py +575 -575
- brainstate/nn/_delay_test.py +243 -243
- brainstate/nn/_dropout.py +618 -618
- brainstate/nn/_dropout_test.py +477 -477
- brainstate/nn/_dynamics.py +1267 -1267
- brainstate/nn/_dynamics_test.py +67 -67
- brainstate/nn/_elementwise.py +1298 -1298
- brainstate/nn/_elementwise_test.py +829 -829
- brainstate/nn/_embedding.py +408 -408
- brainstate/nn/_embedding_test.py +156 -156
- brainstate/nn/_event_fixedprob.py +233 -233
- brainstate/nn/_event_fixedprob_test.py +115 -115
- brainstate/nn/_event_linear.py +83 -83
- brainstate/nn/_event_linear_test.py +121 -121
- brainstate/nn/_exp_euler.py +254 -254
- brainstate/nn/_exp_euler_test.py +377 -377
- brainstate/nn/_linear.py +744 -744
- brainstate/nn/_linear_test.py +475 -475
- brainstate/nn/_metrics.py +1070 -1070
- brainstate/nn/_metrics_test.py +611 -611
- brainstate/nn/_module.py +384 -384
- brainstate/nn/_module_test.py +40 -40
- brainstate/nn/_normalizations.py +1334 -1334
- brainstate/nn/_normalizations_test.py +699 -699
- brainstate/nn/_paddings.py +1020 -1020
- brainstate/nn/_paddings_test.py +722 -722
- brainstate/nn/_poolings.py +2239 -2239
- brainstate/nn/_poolings_test.py +952 -952
- brainstate/nn/_rnns.py +946 -946
- brainstate/nn/_rnns_test.py +592 -592
- brainstate/nn/_utils.py +216 -216
- brainstate/nn/_utils_test.py +401 -401
- brainstate/nn/init.py +809 -809
- brainstate/nn/init_test.py +180 -180
- brainstate/random/__init__.py +270 -270
- brainstate/random/_rand_funs.py +3938 -3938
- brainstate/random/_rand_funs_test.py +640 -640
- brainstate/random/_rand_seed.py +675 -675
- brainstate/random/_rand_seed_test.py +48 -48
- brainstate/random/_rand_state.py +1617 -1617
- brainstate/random/_rand_state_test.py +551 -551
- brainstate/transform/__init__.py +59 -59
- brainstate/transform/_ad_checkpoint.py +176 -176
- brainstate/transform/_ad_checkpoint_test.py +49 -49
- brainstate/transform/_autograd.py +1025 -1025
- brainstate/transform/_autograd_test.py +1289 -1289
- brainstate/transform/_conditions.py +316 -316
- brainstate/transform/_conditions_test.py +220 -220
- brainstate/transform/_error_if.py +94 -94
- brainstate/transform/_error_if_test.py +52 -52
- brainstate/transform/_eval_shape.py +145 -145
- brainstate/transform/_eval_shape_test.py +38 -38
- brainstate/transform/_jit.py +399 -399
- brainstate/transform/_jit_test.py +143 -143
- brainstate/transform/_loop_collect_return.py +675 -675
- brainstate/transform/_loop_collect_return_test.py +58 -58
- brainstate/transform/_loop_no_collection.py +283 -283
- brainstate/transform/_loop_no_collection_test.py +50 -50
- brainstate/transform/_make_jaxpr.py +2016 -2016
- brainstate/transform/_make_jaxpr_test.py +1510 -1510
- brainstate/transform/_mapping.py +529 -529
- brainstate/transform/_mapping_test.py +194 -194
- brainstate/transform/_progress_bar.py +255 -255
- brainstate/transform/_random.py +171 -171
- brainstate/transform/_unvmap.py +256 -256
- brainstate/transform/_util.py +286 -286
- brainstate/typing.py +837 -837
- brainstate/typing_test.py +780 -780
- brainstate/util/__init__.py +27 -27
- brainstate/util/_others.py +1024 -1024
- brainstate/util/_others_test.py +962 -962
- brainstate/util/_pretty_pytree.py +1301 -1301
- brainstate/util/_pretty_pytree_test.py +675 -675
- brainstate/util/_pretty_repr.py +462 -462
- brainstate/util/_pretty_repr_test.py +696 -696
- brainstate/util/filter.py +945 -945
- brainstate/util/filter_test.py +911 -911
- brainstate/util/struct.py +910 -910
- brainstate/util/struct_test.py +602 -602
- {brainstate-0.2.0.dist-info → brainstate-0.2.1.dist-info}/METADATA +108 -108
- brainstate-0.2.1.dist-info/RECORD +111 -0
- {brainstate-0.2.0.dist-info → brainstate-0.2.1.dist-info}/licenses/LICENSE +202 -202
- brainstate-0.2.0.dist-info/RECORD +0 -111
- {brainstate-0.2.0.dist-info → brainstate-0.2.1.dist-info}/WHEEL +0 -0
- {brainstate-0.2.0.dist-info → brainstate-0.2.1.dist-info}/top_level.txt +0 -0
brainstate/nn/_rnns_test.py
CHANGED
@@ -1,593 +1,593 @@
Both sides of this hunk are identical: version 0.2.1 removes and re-adds all 593 lines of the file without changing their content (the matching -N/+N counts across every file in the list above suggest a whole-package re-release or line-ending change rather than a code change). The file is shown once below.

# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Comprehensive tests for RNN cell implementations."""

import unittest
from typing import Type

import jax
import jax.numpy as jnp
import numpy as np

import brainstate
import brainstate.nn as nn
from brainstate.nn import RNNCell, ValinaRNNCell, GRUCell, MGUCell, LSTMCell, URLSTMCell
from brainstate.nn import init as init
from brainstate.nn import _activations as functional


class TestRNNCellBase(unittest.TestCase):
    """Base test class for all RNN cell implementations."""

    def setUp(self):
        """Set up test fixtures."""
        self.num_in = 10
        self.num_out = 20
        self.batch_size = 32
        self.sequence_length = 100
        self.seed = 42

        # Initialize random inputs
        key = jax.random.PRNGKey(self.seed)
        self.x = jax.random.normal(key, (self.batch_size, self.num_in))
        self.sequence = jax.random.normal(
            key, (self.sequence_length, self.batch_size, self.num_in)
        )


class TestVanillaRNNCell(TestRNNCellBase):
    """Comprehensive tests for VanillaRNNCell."""

    def test_basic_forward(self):
        """Test basic forward pass."""
        cell = ValinaRNNCell(num_in=self.num_in, num_out=self.num_out)
        cell.init_state(batch_size=self.batch_size)

        output = cell.update(self.x)

        self.assertEqual(output.shape, (self.batch_size, self.num_out))
        self.assertFalse(jnp.any(jnp.isnan(output)))
        self.assertFalse(jnp.any(jnp.isinf(output)))

    def test_sequence_processing(self):
        """Test processing a sequence of inputs."""
        cell = ValinaRNNCell(num_in=self.num_in, num_out=self.num_out)
        cell.init_state(batch_size=self.batch_size)

        outputs = []
        for t in range(self.sequence_length):
            output = cell.update(self.sequence[t])
            outputs.append(output)

        outputs = jnp.stack(outputs)
        self.assertEqual(outputs.shape, (self.sequence_length, self.batch_size, self.num_out))
        self.assertFalse(jnp.any(jnp.isnan(outputs)))

    def test_state_reset(self):
        """Test state reset functionality."""
        cell = ValinaRNNCell(num_in=self.num_in, num_out=self.num_out)
        cell.init_state(batch_size=self.batch_size)

        # Process some input
        _ = cell.update(self.x)
        state_before = cell.h.value.copy()

        # Reset state
        cell.reset_state(batch_size=self.batch_size)
        state_after = cell.h.value.copy()

        # States should be different (unless randomly the same, which is unlikely)
        self.assertFalse(jnp.allclose(state_before, state_after, atol=1e-6))

    def test_different_batch_sizes(self):
        """Test with different batch sizes."""
        cell = ValinaRNNCell(num_in=self.num_in, num_out=self.num_out)

        for batch_size in [1, 16, 64]:
            cell.init_state(batch_size=batch_size)
            x = jnp.ones((batch_size, self.num_in))
            output = cell.update(x)
            self.assertEqual(output.shape, (batch_size, self.num_out))

    def test_activation_functions(self):
        """Test different activation functions."""
        activations = ['relu', 'tanh', 'sigmoid', 'gelu']

        for activation in activations:
            cell = ValinaRNNCell(
                num_in=self.num_in,
                num_out=self.num_out,
                activation=activation
            )
            cell.init_state(batch_size=self.batch_size)
            output = cell.update(self.x)
            self.assertEqual(output.shape, (self.batch_size, self.num_out))
            self.assertFalse(jnp.any(jnp.isnan(output)))

    def test_custom_initializers(self):
        """Test custom weight and state initializers."""
        cell = ValinaRNNCell(
            num_in=self.num_in,
            num_out=self.num_out,
            w_init=init.Orthogonal(),
            b_init=init.Constant(0.1),
            state_init=init.Normal(0.01)
        )
        cell.init_state(batch_size=self.batch_size)
        output = cell.update(self.x)
        self.assertEqual(output.shape, (self.batch_size, self.num_out))

    def test_gradient_flow(self):
        """Test gradient flow through the cell."""
        cell = ValinaRNNCell(num_in=self.num_in, num_out=self.num_out)
        cell.init_state(batch_size=self.batch_size)

        def loss_fn(x):
            output = cell.update(x)
            return jnp.mean(output ** 2)

        grad_fn = jax.grad(loss_fn)
        grad = grad_fn(self.x)

        self.assertEqual(grad.shape, self.x.shape)
        self.assertFalse(jnp.any(jnp.isnan(grad)))
        self.assertTrue(jnp.any(grad != 0))  # Gradients should be non-zero


class TestGRUCell(TestRNNCellBase):
    """Comprehensive tests for GRUCell."""

    def test_basic_forward(self):
        """Test basic forward pass."""
        cell = GRUCell(num_in=self.num_in, num_out=self.num_out)
        cell.init_state(batch_size=self.batch_size)

        output = cell.update(self.x)

        self.assertEqual(output.shape, (self.batch_size, self.num_out))
        self.assertFalse(jnp.any(jnp.isnan(output)))

    def test_gating_mechanism(self):
        """Test that gating values are in valid range."""
        cell = GRUCell(num_in=self.num_in, num_out=self.num_out)
        cell.init_state(batch_size=self.batch_size)

        # Access internal computation
        old_h = cell.h.value
        xh = jnp.concatenate([self.x, old_h], axis=-1)
        gates = functional.sigmoid(cell.Wrz(xh))

        # Gates should be between 0 and 1
        self.assertTrue(jnp.all(gates >= 0))
        self.assertTrue(jnp.all(gates <= 1))

    def test_state_persistence(self):
        """Test that state persists across updates."""
        cell = GRUCell(num_in=self.num_in, num_out=self.num_out)
        cell.init_state(batch_size=self.batch_size)

        # Process sequence and track states
        states = []
        for t in range(10):
            _ = cell.update(self.sequence[t])
            states.append(cell.h.value.copy())

        # States should evolve over time
        for i in range(1, len(states)):
            self.assertFalse(jnp.allclose(states[i], states[i-1], atol=1e-8))

    def test_reset_vs_update_gates(self):
        """Test that reset and update gates behave differently."""
        cell = GRUCell(num_in=self.num_in, num_out=self.num_out)
        cell.init_state(batch_size=self.batch_size)

        # Get gates for the same input
        old_h = cell.h.value
        xh = jnp.concatenate([self.x, old_h], axis=-1)
        r, z = jnp.split(functional.sigmoid(cell.Wrz(xh)), indices_or_sections=2, axis=-1)

        # Reset and update gates should be different
        self.assertFalse(jnp.allclose(r, z, atol=1e-6))

    def test_different_initializers(self):
        """Test with different weight initializers."""
        initializers = [
            init.XavierNormal(),
            init.XavierUniform(),
            init.Orthogonal(),
            init.KaimingNormal(),
        ]

        for w_init in initializers:
            cell = GRUCell(
                num_in=self.num_in,
                num_out=self.num_out,
                w_init=w_init
            )
            cell.init_state(batch_size=self.batch_size)
            output = cell.update(self.x)
            self.assertEqual(output.shape, (self.batch_size, self.num_out))


class TestMGUCell(TestRNNCellBase):
    """Comprehensive tests for MGUCell."""

    def test_basic_forward(self):
        """Test basic forward pass."""
        cell = MGUCell(num_in=self.num_in, num_out=self.num_out)
        cell.init_state(batch_size=self.batch_size)

        output = cell.update(self.x)

        self.assertEqual(output.shape, (self.batch_size, self.num_out))
        self.assertFalse(jnp.any(jnp.isnan(output)))

    def test_single_gate_mechanism(self):
        """Test that MGU uses single forget gate."""
        cell = MGUCell(num_in=self.num_in, num_out=self.num_out)
        cell.init_state(batch_size=self.batch_size)

        # Check that only one gate is computed
        xh = jnp.concatenate([self.x, cell.h.value], axis=-1)
        f = functional.sigmoid(cell.Wf(xh))

        # Forget gate should be between 0 and 1
        self.assertTrue(jnp.all(f >= 0))
        self.assertTrue(jnp.all(f <= 1))
        self.assertEqual(f.shape, (self.batch_size, self.num_out))

    def test_parameter_efficiency(self):
        """Test that MGU has fewer parameters than GRU."""
        mgu_cell = MGUCell(num_in=self.num_in, num_out=self.num_out)
        gru_cell = GRUCell(num_in=self.num_in, num_out=self.num_out)

        # Count parameters - MGU should have fewer
        # MGU has 2 weight matrices (Wf, Wh)
        # GRU has 2 weight matrices but one is double size (Wrz, Wh)
        mgu_param_count = 2 * ((self.num_in + self.num_out) * self.num_out + self.num_out)
        gru_param_count = ((self.num_in + self.num_out) * (self.num_out * 2) + self.num_out * 2) + \
                          ((self.num_in + self.num_out) * self.num_out + self.num_out)

        self.assertLess(mgu_param_count, gru_param_count)


class TestLSTMCell(TestRNNCellBase):
    """Comprehensive tests for LSTMCell."""

    def test_basic_forward(self):
        """Test basic forward pass."""
        cell = LSTMCell(num_in=self.num_in, num_out=self.num_out)
        cell.init_state(batch_size=self.batch_size)

        output = cell.update(self.x)

        self.assertEqual(output.shape, (self.batch_size, self.num_out))
        self.assertFalse(jnp.any(jnp.isnan(output)))

    def test_dual_state_mechanism(self):
        """Test that LSTM maintains both hidden and cell states."""
        cell = LSTMCell(num_in=self.num_in, num_out=self.num_out)
        cell.init_state(batch_size=self.batch_size)

        # Check initial states
        self.assertIsNotNone(cell.h)
        self.assertIsNotNone(cell.c)
        self.assertEqual(cell.h.value.shape, (self.batch_size, self.num_out))
        self.assertEqual(cell.c.value.shape, (self.batch_size, self.num_out))

        # Update and check states change
        h_before = cell.h.value.copy()
        c_before = cell.c.value.copy()

        _ = cell.update(self.x)

        self.assertFalse(jnp.allclose(cell.h.value, h_before, atol=1e-8))
        self.assertFalse(jnp.allclose(cell.c.value, c_before, atol=1e-8))

    def test_forget_gate_bias(self):
        """Test that forget gate has positive bias initialization."""
        cell = LSTMCell(num_in=self.num_in, num_out=self.num_out)
        cell.init_state(batch_size=self.batch_size)

        # Process with zero input to see bias effect
        zero_input = jnp.zeros((self.batch_size, self.num_in))
        xh = jnp.concatenate([zero_input, cell.h.value], axis=-1)
        gates = cell.W(xh)
        _, _, f, _ = jnp.split(gates, indices_or_sections=4, axis=-1)
        f_gate = functional.sigmoid(f + 1.)  # Note the +1 bias

        # Forget gate should be biased towards remembering (> 0.5)
        self.assertTrue(jnp.mean(f_gate) > 0.5)

    def test_gate_values_range(self):
        """Test that all gates produce values in [0, 1]."""
        cell = LSTMCell(num_in=self.num_in, num_out=self.num_out)
        cell.init_state(batch_size=self.batch_size)

        xh = jnp.concatenate([self.x, cell.h.value], axis=-1)
        i, g, f, o = jnp.split(cell.W(xh), indices_or_sections=4, axis=-1)

        i_gate = functional.sigmoid(i)
        f_gate = functional.sigmoid(f + 1.)
        o_gate = functional.sigmoid(o)

        for gate in [i_gate, f_gate, o_gate]:
            self.assertTrue(jnp.all(gate >= 0))
            self.assertTrue(jnp.all(gate <= 1))

    def test_cell_state_gradient_flow(self):
        """Test gradient flow through cell state."""
        cell = LSTMCell(num_in=self.num_in, num_out=self.num_out)
        cell.init_state(batch_size=self.batch_size)

        def loss_fn(x):
            for t in range(10):
                _ = cell.update(x)
            return jnp.mean(cell.c.value ** 2)

        grad_fn = jax.grad(loss_fn)
        grad = grad_fn(self.x)

        self.assertFalse(jnp.any(jnp.isnan(grad)))
        self.assertTrue(jnp.any(grad != 0))


class TestURLSTMCell(TestRNNCellBase):
    """Comprehensive tests for URLSTMCell."""

    def test_basic_forward(self):
        """Test basic forward pass."""
        cell = URLSTMCell(num_in=self.num_in, num_out=self.num_out)
        cell.init_state(batch_size=self.batch_size)

        output = cell.update(self.x)

        self.assertEqual(output.shape, (self.batch_size, self.num_out))
        self.assertFalse(jnp.any(jnp.isnan(output)))

    def test_untied_bias_mechanism(self):
        """Test the untied bias initialization."""
        cell = URLSTMCell(num_in=self.num_in, num_out=self.num_out)
        cell.init_state(batch_size=self.batch_size)

        # Check bias values are initialized
        self.assertIsNotNone(cell.bias.value)
        self.assertEqual(cell.bias.value.shape, (self.num_out,))

        # Biases should be diverse (not all the same)
        self.assertGreater(jnp.std(cell.bias.value), 0.1)

    def test_unified_gate_computation(self):
        """Test the unified gate mechanism."""
        cell = URLSTMCell(num_in=self.num_in, num_out=self.num_out)
        cell.init_state(batch_size=self.batch_size)

        h, c = cell.h.value, cell.c.value
        xh = jnp.concatenate([self.x, h], axis=-1)
        gates = cell.W(xh)
        f, r, u, o = jnp.split(gates, indices_or_sections=4, axis=-1)

        f_gate = functional.sigmoid(f + cell.bias.value)
        r_gate = functional.sigmoid(r - cell.bias.value)

        # Compute unified gate
        g = 2 * r_gate * f_gate + (1 - 2 * r_gate) * f_gate ** 2

        # Unified gate should be in [0, 1]
        self.assertTrue(jnp.all(g >= 0))
        self.assertTrue(jnp.all(g <= 1))

    def test_comparison_with_lstm(self):
        """Test that URLSTM behaves differently from standard LSTM."""
        urlstm = URLSTMCell(num_in=self.num_in, num_out=self.num_out, state_init=init.Constant(0.5))
        lstm = LSTMCell(num_in=self.num_in, num_out=self.num_out, state_init=init.Constant(0.5))

        urlstm.init_state(batch_size=self.batch_size)
        lstm.init_state(batch_size=self.batch_size)

        # Same input should produce different outputs
        urlstm_out = urlstm.update(self.x)
        lstm_out = lstm.update(self.x)

        self.assertFalse(jnp.allclose(urlstm_out, lstm_out, atol=1e-4))


class TestRNNCellIntegration(TestRNNCellBase):
    """Integration tests for all RNN cells."""

    def test_all_cells_compatible_interface(self):
        """Test that all cells have compatible interfaces."""
        cell_types = [ValinaRNNCell, GRUCell, MGUCell, LSTMCell, URLSTMCell]

        for CellType in cell_types:
            cell = CellType(num_in=self.num_in, num_out=self.num_out)

            # Test init_state
            cell.init_state(batch_size=self.batch_size)

            # Test update
            output = cell.update(self.x)
            self.assertEqual(output.shape, (self.batch_size, self.num_out))

            # Test reset_state
            cell.reset_state(batch_size=16)

            # Test with new batch size
            x_small = jnp.ones((16, self.num_in))
            output_small = cell.update(x_small)
            self.assertEqual(output_small.shape, (16, self.num_out))

    def test_sequence_to_sequence(self):
        """Test sequence-to-sequence processing."""
        cell_types = [ValinaRNNCell, GRUCell, MGUCell, LSTMCell, URLSTMCell]

        for CellType in cell_types:
            cell = CellType(num_in=self.num_in, num_out=self.num_out)
            cell.init_state(batch_size=self.batch_size)

            outputs = []
            for t in range(self.sequence_length):
                output = cell.update(self.sequence[t])
                outputs.append(output)

            outputs = jnp.stack(outputs)
            self.assertEqual(
                outputs.shape,
                (self.sequence_length, self.batch_size, self.num_out)
            )

    def test_variable_length_sequences(self):
        """Test handling of variable length sequences with masking."""
        cell = LSTMCell(num_in=self.num_in, num_out=self.num_out)
        cell.init_state(batch_size=self.batch_size)

        # Create mask for variable lengths
        lengths = jnp.array([10, 20, 30, 40] * (self.batch_size // 4))
        mask = jnp.arange(self.sequence_length)[:, None] < lengths[None, :]

        outputs = []
        for t in range(self.sequence_length):
            output = cell.update(self.sequence[t])
            # Apply mask
            output = output * mask[t:t+1].T
            outputs.append(output)

        outputs = jnp.stack(outputs)

        # Check that masked positions are zero
        for b in range(self.batch_size):
            length = lengths[b]
            if length < self.sequence_length:
                self.assertTrue(jnp.allclose(outputs[length:, b, :], 0.0))

    def test_gradient_clipping(self):
        """Test gradient clipping during training."""
        cell = GRUCell(num_in=self.num_in, num_out=self.num_out)
        cell.init_state(batch_size=self.batch_size)

        def loss_fn(x):
            output = jnp.zeros((self.batch_size, self.num_out))
            for t in range(50):  # Long sequence
                output = cell.update(x * (t + 1))  # Amplify input
            return jnp.mean(output ** 2)

        grad_fn = jax.grad(loss_fn)
        grad = grad_fn(self.x)

        # Gradients should not explode
        self.assertFalse(jnp.any(jnp.isnan(grad)))
        self.assertFalse(jnp.any(jnp.isinf(grad)))
        self.assertLess(jnp.max(jnp.abs(grad)), 1e6)


class TestRNNCellEdgeCases(TestRNNCellBase):
    """Edge case tests for RNN cells."""

    def test_single_sample(self):
        """Test with batch size of 1."""
        cell = ValinaRNNCell(num_in=self.num_in, num_out=self.num_out)
        cell.init_state(batch_size=1)

        x = jnp.ones((1, self.num_in))
        output = cell.update(x)
        self.assertEqual(output.shape, (1, self.num_out))

    def test_zero_input(self):
        """Test with zero inputs."""
        cell_types = [ValinaRNNCell, GRUCell, MGUCell, LSTMCell, URLSTMCell]

        for CellType in cell_types:
            cell = CellType(num_in=self.num_in, num_out=self.num_out)
            cell.init_state(batch_size=self.batch_size)

            zero_input = jnp.zeros((self.batch_size, self.num_in))
            output = cell.update(zero_input)

            self.assertEqual(output.shape, (self.batch_size, self.num_out))
            self.assertFalse(jnp.any(jnp.isnan(output)))

    def test_large_input_values(self):
        """Test with large input values."""
        cell = LSTMCell(num_in=self.num_in, num_out=self.num_out)
        cell.init_state(batch_size=self.batch_size)

        large_input = jnp.ones((self.batch_size, self.num_in)) * 100
        output = cell.update(large_input)

        # Should handle large inputs gracefully (sigmoid saturation)
        self.assertFalse(jnp.any(jnp.isnan(output)))
        self.assertFalse(jnp.any(jnp.isinf(output)))

    def test_very_long_sequence(self):
        """Test with very long sequences."""
        cell = GRUCell(num_in=self.num_in, num_out=self.num_out)
        cell.init_state(batch_size=4)  # Smaller batch for memory

        long_sequence = jnp.ones((1000, 4, self.num_in))

        final_output = None
        for t in range(1000):
            final_output = cell.update(long_sequence[t])

        # Should not have numerical issues even after long sequence
        self.assertFalse(jnp.any(jnp.isnan(final_output)))
        self.assertFalse(jnp.any(jnp.isinf(final_output)))

    def test_dimension_mismatch_error(self):
        """Test that dimension mismatches raise appropriate errors."""
        cell = ValinaRNNCell(num_in=self.num_in, num_out=self.num_out)
        cell.init_state(batch_size=self.batch_size)

        # Wrong input dimension should raise error
        wrong_input = jnp.ones((self.batch_size, self.num_in + 5))

        with self.assertRaises(Exception):
            _ = cell.update(wrong_input)


class TestRNNCellProperties(TestRNNCellBase):
    """Test cell properties and attributes."""

    def test_cell_attributes(self):
        """Test that cells have correct attributes."""
        cell = LSTMCell(num_in=self.num_in, num_out=self.num_out)

        self.assertEqual(cell.num_in, self.num_in)
        self.assertEqual(cell.num_out, self.num_out)
        self.assertEqual(cell.in_size, (self.num_in,))
        self.assertEqual(cell.out_size, (self.num_out,))

    def test_inheritance_structure(self):
        """Test that all cells inherit from RNNCell."""
        cell_types = [ValinaRNNCell, GRUCell, MGUCell, LSTMCell, URLSTMCell]

        for CellType in cell_types:
            cell = CellType(num_in=self.num_in, num_out=self.num_out)
            self.assertIsInstance(cell, RNNCell)

    def test_docstring_presence(self):
        """Test that all cells have proper docstrings."""
        cell_types = [ValinaRNNCell, GRUCell, MGUCell, LSTMCell, URLSTMCell]

        for CellType in cell_types:
            self.assertIsNotNone(CellType.__doc__)
            self.assertIn("Examples", CellType.__doc__)
            self.assertIn("Parameters", CellType.__doc__)
            self.assertIn(">>>", CellType.__doc__)


if __name__ == '__main__':
    unittest.main()
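For orientation, here is a minimal sketch of the cell API these tests exercise. The class and method names (GRUCell, init_state, update) and the parameter-count formulas are taken directly from the test file above; the concrete sizes and assertions are illustrative assumptions, not part of the brainstate documentation.

import jax
import jax.numpy as jnp
import brainstate.nn as nn

# Build a GRU cell mapping 10 input features to 20 hidden units,
# then allocate its hidden state for a batch of 8 samples.
cell = nn.GRUCell(num_in=10, num_out=20)
cell.init_state(batch_size=8)

# Step the cell over a (time, batch, features) sequence, as the tests do.
xs = jax.random.normal(jax.random.PRNGKey(0), (5, 8, 10))
outputs = jnp.stack([cell.update(x_t) for x_t in xs])
assert outputs.shape == (5, 8, 20)

# Parameter-count arithmetic from test_parameter_efficiency, plugging in
# the fixture values num_in=10, num_out=20:
mgu_params = 2 * ((10 + 20) * 20 + 20)                      # 2 * 620  = 1240
gru_params = ((10 + 20) * 40 + 40) + ((10 + 20) * 20 + 20)  # 1240 + 620 = 1860
assert mgu_params < gru_params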