brainstate 0.1.10__py2.py3-none-any.whl → 0.2.1__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +169 -58
- brainstate/_compatible_import.py +340 -148
- brainstate/_compatible_import_test.py +681 -0
- brainstate/_deprecation.py +210 -0
- brainstate/_deprecation_test.py +2319 -0
- brainstate/{util/error.py → _error.py} +45 -55
- brainstate/_state.py +1652 -1605
- brainstate/_state_test.py +52 -52
- brainstate/_utils.py +47 -47
- brainstate/environ.py +1495 -563
- brainstate/environ_test.py +1223 -62
- brainstate/graph/__init__.py +22 -29
- brainstate/graph/_node.py +240 -0
- brainstate/graph/_node_test.py +589 -0
- brainstate/graph/{_graph_operation.py → _operation.py} +1624 -1738
- brainstate/graph/_operation_test.py +1147 -0
- brainstate/mixin.py +1433 -365
- brainstate/mixin_test.py +1017 -77
- brainstate/nn/__init__.py +137 -135
- brainstate/nn/_activations.py +1100 -808
- brainstate/nn/_activations_test.py +354 -331
- brainstate/nn/_collective_ops.py +633 -514
- brainstate/nn/_collective_ops_test.py +774 -43
- brainstate/nn/_common.py +226 -178
- brainstate/nn/_common_test.py +154 -0
- brainstate/nn/_conv.py +2010 -501
- brainstate/nn/_conv_test.py +849 -238
- brainstate/nn/_delay.py +575 -588
- brainstate/nn/_delay_test.py +243 -238
- brainstate/nn/_dropout.py +618 -426
- brainstate/nn/_dropout_test.py +477 -100
- brainstate/nn/_dynamics.py +1267 -1343
- brainstate/nn/_dynamics_test.py +67 -78
- brainstate/nn/_elementwise.py +1298 -1119
- brainstate/nn/_elementwise_test.py +830 -169
- brainstate/nn/_embedding.py +408 -58
- brainstate/nn/_embedding_test.py +156 -0
- brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +233 -239
- brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +115 -114
- brainstate/nn/{_linear_mv.py → _event_linear.py} +83 -83
- brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +121 -120
- brainstate/nn/_exp_euler.py +254 -92
- brainstate/nn/_exp_euler_test.py +377 -35
- brainstate/nn/_linear.py +744 -424
- brainstate/nn/_linear_test.py +475 -107
- brainstate/nn/_metrics.py +1070 -0
- brainstate/nn/_metrics_test.py +611 -0
- brainstate/nn/_module.py +384 -377
- brainstate/nn/_module_test.py +40 -40
- brainstate/nn/_normalizations.py +1334 -975
- brainstate/nn/_normalizations_test.py +699 -73
- brainstate/nn/_paddings.py +1020 -0
- brainstate/nn/_paddings_test.py +723 -0
- brainstate/nn/_poolings.py +2239 -1177
- brainstate/nn/_poolings_test.py +953 -217
- brainstate/nn/{_rate_rnns.py → _rnns.py} +946 -554
- brainstate/nn/_rnns_test.py +593 -0
- brainstate/nn/_utils.py +216 -89
- brainstate/nn/_utils_test.py +402 -0
- brainstate/{init/_random_inits.py → nn/init.py} +809 -553
- brainstate/{init/_random_inits_test.py → nn/init_test.py} +180 -149
- brainstate/random/__init__.py +270 -24
- brainstate/random/_rand_funs.py +3938 -3616
- brainstate/random/_rand_funs_test.py +640 -567
- brainstate/random/_rand_seed.py +675 -210
- brainstate/random/_rand_seed_test.py +48 -48
- brainstate/random/_rand_state.py +1617 -1409
- brainstate/random/_rand_state_test.py +551 -0
- brainstate/transform/__init__.py +59 -0
- brainstate/transform/_ad_checkpoint.py +176 -0
- brainstate/{compile → transform}/_ad_checkpoint_test.py +49 -49
- brainstate/{augment → transform}/_autograd.py +1025 -778
- brainstate/{augment → transform}/_autograd_test.py +1289 -1289
- brainstate/transform/_conditions.py +316 -0
- brainstate/{compile → transform}/_conditions_test.py +220 -220
- brainstate/{compile → transform}/_error_if.py +94 -92
- brainstate/{compile → transform}/_error_if_test.py +52 -52
- brainstate/transform/_eval_shape.py +145 -0
- brainstate/{augment → transform}/_eval_shape_test.py +38 -38
- brainstate/{compile → transform}/_jit.py +399 -346
- brainstate/{compile → transform}/_jit_test.py +143 -143
- brainstate/{compile → transform}/_loop_collect_return.py +675 -536
- brainstate/{compile → transform}/_loop_collect_return_test.py +58 -58
- brainstate/{compile → transform}/_loop_no_collection.py +283 -184
- brainstate/{compile → transform}/_loop_no_collection_test.py +50 -50
- brainstate/transform/_make_jaxpr.py +2016 -0
- brainstate/transform/_make_jaxpr_test.py +1510 -0
- brainstate/transform/_mapping.py +529 -0
- brainstate/transform/_mapping_test.py +194 -0
- brainstate/{compile → transform}/_progress_bar.py +255 -202
- brainstate/{augment → transform}/_random.py +171 -151
- brainstate/{compile → transform}/_unvmap.py +256 -159
- brainstate/transform/_util.py +286 -0
- brainstate/typing.py +837 -304
- brainstate/typing_test.py +780 -0
- brainstate/util/__init__.py +27 -50
- brainstate/util/_others.py +1025 -0
- brainstate/util/_others_test.py +962 -0
- brainstate/util/_pretty_pytree.py +1301 -0
- brainstate/util/_pretty_pytree_test.py +675 -0
- brainstate/util/{pretty_repr.py → _pretty_repr.py} +462 -328
- brainstate/util/_pretty_repr_test.py +696 -0
- brainstate/util/filter.py +945 -469
- brainstate/util/filter_test.py +912 -0
- brainstate/util/struct.py +910 -523
- brainstate/util/struct_test.py +602 -0
- {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/METADATA +108 -91
- brainstate-0.2.1.dist-info/RECORD +111 -0
- {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/licenses/LICENSE +202 -202
- brainstate/augment/__init__.py +0 -30
- brainstate/augment/_eval_shape.py +0 -99
- brainstate/augment/_mapping.py +0 -1060
- brainstate/augment/_mapping_test.py +0 -597
- brainstate/compile/__init__.py +0 -38
- brainstate/compile/_ad_checkpoint.py +0 -204
- brainstate/compile/_conditions.py +0 -256
- brainstate/compile/_make_jaxpr.py +0 -888
- brainstate/compile/_make_jaxpr_test.py +0 -156
- brainstate/compile/_util.py +0 -147
- brainstate/functional/__init__.py +0 -27
- brainstate/graph/_graph_node.py +0 -244
- brainstate/graph/_graph_node_test.py +0 -73
- brainstate/graph/_graph_operation_test.py +0 -563
- brainstate/init/__init__.py +0 -26
- brainstate/init/_base.py +0 -52
- brainstate/init/_generic.py +0 -244
- brainstate/init/_regular_inits.py +0 -105
- brainstate/init/_regular_inits_test.py +0 -50
- brainstate/nn/_inputs.py +0 -608
- brainstate/nn/_ltp.py +0 -28
- brainstate/nn/_neuron.py +0 -705
- brainstate/nn/_neuron_test.py +0 -161
- brainstate/nn/_others.py +0 -46
- brainstate/nn/_projection.py +0 -486
- brainstate/nn/_rate_rnns_test.py +0 -63
- brainstate/nn/_readout.py +0 -209
- brainstate/nn/_readout_test.py +0 -53
- brainstate/nn/_stp.py +0 -236
- brainstate/nn/_synapse.py +0 -505
- brainstate/nn/_synapse_test.py +0 -131
- brainstate/nn/_synaptic_projection.py +0 -423
- brainstate/nn/_synouts.py +0 -162
- brainstate/nn/_synouts_test.py +0 -57
- brainstate/nn/metrics.py +0 -388
- brainstate/optim/__init__.py +0 -38
- brainstate/optim/_base.py +0 -64
- brainstate/optim/_lr_scheduler.py +0 -448
- brainstate/optim/_lr_scheduler_test.py +0 -50
- brainstate/optim/_optax_optimizer.py +0 -152
- brainstate/optim/_optax_optimizer_test.py +0 -53
- brainstate/optim/_sgd_optimizer.py +0 -1104
- brainstate/random/_random_for_unit.py +0 -52
- brainstate/surrogate.py +0 -1957
- brainstate/transform.py +0 -23
- brainstate/util/caller.py +0 -98
- brainstate/util/others.py +0 -540
- brainstate/util/pretty_pytree.py +0 -945
- brainstate/util/pretty_pytree_test.py +0 -159
- brainstate/util/pretty_table.py +0 -2954
- brainstate/util/scaling.py +0 -258
- brainstate-0.1.10.dist-info/RECORD +0 -130
- {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/WHEEL +0 -0
- {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/top_level.txt +0 -0
brainstate/mixin_test.py
CHANGED
@@ -1,77 +1,1017 @@
|
|
1
|
-
# Copyright 2024
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
15
|
-
|
16
|
-
|
17
|
-
|
18
|
-
|
19
|
-
|
20
|
-
|
21
|
-
|
22
|
-
|
23
|
-
|
24
|
-
|
25
|
-
|
26
|
-
|
27
|
-
|
28
|
-
|
29
|
-
|
30
|
-
|
31
|
-
|
32
|
-
|
33
|
-
|
34
|
-
|
35
|
-
|
36
|
-
|
37
|
-
def
|
38
|
-
|
39
|
-
self.assertTrue(
|
40
|
-
self.assertTrue(
|
41
|
-
self.assertTrue(
|
42
|
-
|
43
|
-
self.assertTrue(
|
44
|
-
|
45
|
-
|
46
|
-
|
47
|
-
|
48
|
-
|
49
|
-
self.assertTrue(
|
50
|
-
|
51
|
-
|
52
|
-
|
53
|
-
self.assertFalse(
|
54
|
-
self.assertFalse(
|
55
|
-
|
56
|
-
|
57
|
-
|
58
|
-
|
59
|
-
|
60
|
-
|
61
|
-
|
62
|
-
|
63
|
-
|
64
|
-
self.
|
65
|
-
|
66
|
-
def
|
67
|
-
a
|
68
|
-
|
69
|
-
|
70
|
-
|
71
|
-
|
72
|
-
|
73
|
-
|
74
|
-
|
75
|
-
|
76
|
-
|
77
|
-
|
1
|
+
# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
"""
|
17
|
+
Comprehensive tests for brainstate.mixin module.
|
18
|
+
|
19
|
+
This test suite covers all functionality in the mixin module including:
|
20
|
+
- Base mixin classes
|
21
|
+
- Parameter description and deferred instantiation
|
22
|
+
- Type utilities (JointTypes, OneOfTypes)
|
23
|
+
- Mode system (Mode, JointMode, Training, Batching)
|
24
|
+
- Helper utilities (hashable, not_implemented, etc.)
|
25
|
+
"""
|
26
|
+
|
27
|
+
import unittest
|
28
|
+
|
29
|
+
import jax.numpy as jnp
|
30
|
+
|
31
|
+
import brainstate
|
32
|
+
|
33
|
+
|
34
|
+
class TestHashableFunction(unittest.TestCase):
|
35
|
+
"""Test the hashable utility function."""
|
36
|
+
|
37
|
+
def test_hashable_primitives(self):
|
38
|
+
"""Test hashable with primitive types."""
|
39
|
+
self.assertTrue(brainstate.mixin.hashable(42))
|
40
|
+
self.assertTrue(brainstate.mixin.hashable(3.14))
|
41
|
+
self.assertTrue(brainstate.mixin.hashable("string"))
|
42
|
+
self.assertTrue(brainstate.mixin.hashable(True))
|
43
|
+
self.assertTrue(brainstate.mixin.hashable(None))
|
44
|
+
|
45
|
+
def test_hashable_tuples(self):
|
46
|
+
"""Test hashable with tuples."""
|
47
|
+
self.assertTrue(brainstate.mixin.hashable((1, 2, 3)))
|
48
|
+
self.assertTrue(brainstate.mixin.hashable(("a", "b")))
|
49
|
+
self.assertTrue(brainstate.mixin.hashable(()))
|
50
|
+
|
51
|
+
def test_non_hashable_types(self):
|
52
|
+
"""Test non-hashable types."""
|
53
|
+
self.assertFalse(brainstate.mixin.hashable([1, 2, 3]))
|
54
|
+
self.assertFalse(brainstate.mixin.hashable({"key": "value"}))
|
55
|
+
self.assertFalse(brainstate.mixin.hashable({1, 2, 3}))
|
56
|
+
self.assertFalse(brainstate.mixin.hashable(jnp.array([1, 2, 3])))
|
57
|
+
|
58
|
+
|
59
|
+
class TestMixin(unittest.TestCase):
|
60
|
+
"""Test the base Mixin class."""
|
61
|
+
|
62
|
+
def test_mixin_exists(self):
|
63
|
+
"""Test that Mixin class exists."""
|
64
|
+
self.assertTrue(brainstate.mixin.Mixin)
|
65
|
+
|
66
|
+
def test_mixin_inheritance(self):
|
67
|
+
"""Test creating a custom mixin."""
|
68
|
+
|
69
|
+
class LoggingMixin(brainstate.mixin.Mixin):
|
70
|
+
def log(self, message):
|
71
|
+
return f"[LOG] {message}"
|
72
|
+
|
73
|
+
class Component(LoggingMixin):
|
74
|
+
pass
|
75
|
+
|
76
|
+
comp = Component()
|
77
|
+
self.assertEqual(comp.log("test"), "[LOG] test")
|
78
|
+
|
79
|
+
def test_mixin_multiple_inheritance(self):
|
80
|
+
"""Test multiple mixin inheritance."""
|
81
|
+
|
82
|
+
class MixinA(brainstate.mixin.Mixin):
|
83
|
+
def method_a(self):
|
84
|
+
return "A"
|
85
|
+
|
86
|
+
class MixinB(brainstate.mixin.Mixin):
|
87
|
+
def method_b(self):
|
88
|
+
return "B"
|
89
|
+
|
90
|
+
class Component(MixinA, MixinB):
|
91
|
+
pass
|
92
|
+
|
93
|
+
comp = Component()
|
94
|
+
self.assertEqual(comp.method_a(), "A")
|
95
|
+
self.assertEqual(comp.method_b(), "B")
|
96
|
+
|
97
|
+
|
98
|
+
class TestParamDesc(unittest.TestCase):
|
99
|
+
"""Test ParamDesc mixin and ParamDescriber."""
|
100
|
+
|
101
|
+
def test_param_desc_basic(self):
|
102
|
+
"""Test basic ParamDesc functionality."""
|
103
|
+
|
104
|
+
class Network(brainstate.mixin.ParamDesc):
|
105
|
+
def __init__(self, size, learning_rate=0.01):
|
106
|
+
self.size = size
|
107
|
+
self.learning_rate = learning_rate
|
108
|
+
|
109
|
+
# Test desc method exists
|
110
|
+
self.assertTrue(hasattr(Network, 'desc'))
|
111
|
+
|
112
|
+
# Create a descriptor
|
113
|
+
desc = Network.desc(size=100)
|
114
|
+
self.assertIsInstance(desc, brainstate.mixin.ParamDescriber)
|
115
|
+
|
116
|
+
def test_param_describer_instantiation(self):
|
117
|
+
"""Test ParamDescriber can create instances."""
|
118
|
+
|
119
|
+
class Network(brainstate.mixin.ParamDesc):
|
120
|
+
def __init__(self, size, learning_rate=0.01):
|
121
|
+
self.size = size
|
122
|
+
self.learning_rate = learning_rate
|
123
|
+
|
124
|
+
desc = Network.desc(size=100, learning_rate=0.001)
|
125
|
+
|
126
|
+
# Create instances
|
127
|
+
net1 = desc()
|
128
|
+
self.assertEqual(net1.size, 100)
|
129
|
+
self.assertEqual(net1.learning_rate, 0.001)
|
130
|
+
|
131
|
+
# Create with overrides
|
132
|
+
net2 = desc(learning_rate=0.005)
|
133
|
+
self.assertEqual(net2.size, 100)
|
134
|
+
self.assertEqual(net2.learning_rate, 0.005)
|
135
|
+
|
136
|
+
def test_param_describer_init_method(self):
|
137
|
+
"""Test ParamDescriber.init() method."""
|
138
|
+
|
139
|
+
class Model(brainstate.mixin.ParamDesc):
|
140
|
+
def __init__(self, value):
|
141
|
+
self.value = value
|
142
|
+
|
143
|
+
desc = Model.desc(value=42)
|
144
|
+
instance = desc.init()
|
145
|
+
self.assertEqual(instance.value, 42)
|
146
|
+
|
147
|
+
def test_param_describer_identifier(self):
|
148
|
+
"""Test ParamDescriber identifier property."""
|
149
|
+
|
150
|
+
class Model(brainstate.mixin.ParamDesc):
|
151
|
+
def __init__(self, x, y=10):
|
152
|
+
self.x = x
|
153
|
+
self.y = y
|
154
|
+
|
155
|
+
desc = Model.desc(x=5, y=20)
|
156
|
+
identifier = desc.identifier
|
157
|
+
|
158
|
+
# Identifier should be a tuple
|
159
|
+
self.assertIsInstance(identifier, tuple)
|
160
|
+
self.assertEqual(len(identifier), 3)
|
161
|
+
self.assertEqual(identifier[0], Model)
|
162
|
+
|
163
|
+
# Identifier should be read-only
|
164
|
+
with self.assertRaises(AttributeError):
|
165
|
+
desc.identifier = "new"
|
166
|
+
|
167
|
+
def test_param_describer_class_getitem(self):
|
168
|
+
"""Test ParamDescriber[Class] notation."""
|
169
|
+
|
170
|
+
class Model:
|
171
|
+
def __init__(self, value):
|
172
|
+
self.value = value
|
173
|
+
|
174
|
+
desc = brainstate.mixin.ParamDescriber[Model]
|
175
|
+
self.assertIsInstance(desc, brainstate.mixin.ParamDescriber)
|
176
|
+
self.assertEqual(desc.cls, Model)
|
177
|
+
|
178
|
+
def test_no_subclass_meta(self):
|
179
|
+
"""Test that ParamDescriber cannot be subclassed."""
|
180
|
+
|
181
|
+
with self.assertRaises(TypeError):
|
182
|
+
class CustomDescriber(brainstate.mixin.ParamDescriber):
|
183
|
+
pass
|
184
|
+
|
185
|
+
|
186
|
+
class TestHashableDict(unittest.TestCase):
|
187
|
+
"""Test HashableDict class."""
|
188
|
+
|
189
|
+
def test_hashable_dict_basic(self):
|
190
|
+
"""Test basic HashableDict functionality."""
|
191
|
+
d = brainstate.mixin.HashableDict({"a": 1, "b": 2})
|
192
|
+
h = hash(d)
|
193
|
+
self.assertIsInstance(h, int)
|
194
|
+
|
195
|
+
def test_hashable_dict_with_arrays(self):
|
196
|
+
"""Test HashableDict with non-hashable values."""
|
197
|
+
d = brainstate.mixin.HashableDict({
|
198
|
+
"array": jnp.array([1, 2, 3]),
|
199
|
+
"value": 42
|
200
|
+
})
|
201
|
+
h = hash(d)
|
202
|
+
self.assertIsInstance(h, int)
|
203
|
+
|
204
|
+
def test_hashable_dict_consistency(self):
|
205
|
+
"""Test that equal dicts have equal hashes."""
|
206
|
+
d1 = brainstate.mixin.HashableDict({"a": 1, "b": 2})
|
207
|
+
d2 = brainstate.mixin.HashableDict({"b": 2, "a": 1})
|
208
|
+
self.assertEqual(hash(d1), hash(d2))
|
209
|
+
|
210
|
+
def test_hashable_dict_usable_as_key(self):
|
211
|
+
"""Test that HashableDict can be used as dict key."""
|
212
|
+
d = brainstate.mixin.HashableDict({"x": 10})
|
213
|
+
cache = {d: "result"}
|
214
|
+
self.assertEqual(cache[d], "result")
|
215
|
+
|
216
|
+
|
217
|
+
class TestJointTypes(unittest.TestCase):
|
218
|
+
"""Test JointTypes functionality."""
|
219
|
+
|
220
|
+
def test_joint_types_basic(self):
|
221
|
+
"""Test basic JointTypes creation."""
|
222
|
+
|
223
|
+
class A:
|
224
|
+
pass
|
225
|
+
|
226
|
+
class B:
|
227
|
+
pass
|
228
|
+
|
229
|
+
JointAB = brainstate.mixin.JointTypes(A, B)
|
230
|
+
self.assertIsNotNone(JointAB)
|
231
|
+
|
232
|
+
def test_joint_types_isinstance(self):
|
233
|
+
"""Test isinstance with JointTypes."""
|
234
|
+
|
235
|
+
class Serializable:
|
236
|
+
def save(self):
|
237
|
+
pass
|
238
|
+
|
239
|
+
class Visualizable:
|
240
|
+
def plot(self):
|
241
|
+
pass
|
242
|
+
|
243
|
+
Combined = brainstate.mixin.JointTypes(Serializable, Visualizable)
|
244
|
+
|
245
|
+
class Model(Serializable, Visualizable):
|
246
|
+
def save(self):
|
247
|
+
return "saved"
|
248
|
+
|
249
|
+
def plot(self):
|
250
|
+
return "plotted"
|
251
|
+
|
252
|
+
model = Model()
|
253
|
+
self.assertTrue(isinstance(model, Combined))
|
254
|
+
|
255
|
+
def test_joint_types_issubclass(self):
|
256
|
+
"""Test issubclass with JointTypes."""
|
257
|
+
|
258
|
+
class A:
|
259
|
+
pass
|
260
|
+
|
261
|
+
class B:
|
262
|
+
pass
|
263
|
+
|
264
|
+
JointAB = brainstate.mixin.JointTypes(A, B)
|
265
|
+
|
266
|
+
class C(A, B):
|
267
|
+
pass
|
268
|
+
|
269
|
+
self.assertTrue(issubclass(C, JointAB))
|
270
|
+
|
271
|
+
def test_joint_types_single_type(self):
|
272
|
+
"""Test JointTypes with single type returns that type."""
|
273
|
+
|
274
|
+
class A:
|
275
|
+
pass
|
276
|
+
|
277
|
+
result = brainstate.mixin.JointTypes(A)
|
278
|
+
self.assertEqual(result, A)
|
279
|
+
|
280
|
+
def test_joint_types_no_types(self):
|
281
|
+
"""Test JointTypes with no types raises error."""
|
282
|
+
with self.assertRaises(TypeError):
|
283
|
+
brainstate.mixin.JointTypes()
|
284
|
+
|
285
|
+
def test_joint_types_removes_duplicates(self):
|
286
|
+
"""Test that JointTypes removes duplicate types."""
|
287
|
+
|
288
|
+
class A:
|
289
|
+
pass
|
290
|
+
|
291
|
+
# Should handle duplicates gracefully
|
292
|
+
JointA = brainstate.mixin.JointTypes(A, A, A)
|
293
|
+
self.assertEqual(JointA, A)
|
294
|
+
|
295
|
+
|
296
|
+
class TestOneOfTypes(unittest.TestCase):
|
297
|
+
"""Test OneOfTypes functionality."""
|
298
|
+
|
299
|
+
def test_one_of_types_basic(self):
|
300
|
+
"""Test basic OneOfTypes creation."""
|
301
|
+
IntOrFloat = brainstate.mixin.OneOfTypes(int, float)
|
302
|
+
self.assertIsNotNone(IntOrFloat)
|
303
|
+
|
304
|
+
def test_one_of_types_isinstance(self):
|
305
|
+
"""Test isinstance with OneOfTypes."""
|
306
|
+
NumType = brainstate.mixin.OneOfTypes(int, float)
|
307
|
+
|
308
|
+
self.assertTrue(isinstance(42, NumType))
|
309
|
+
self.assertTrue(isinstance(3.14, NumType))
|
310
|
+
self.assertFalse(isinstance("hello", NumType))
|
311
|
+
|
312
|
+
def test_one_of_types_single_type(self):
|
313
|
+
"""Test OneOfTypes with single type returns that type."""
|
314
|
+
result = brainstate.mixin.OneOfTypes(int)
|
315
|
+
self.assertEqual(result, int)
|
316
|
+
|
317
|
+
def test_one_of_types_no_types(self):
|
318
|
+
"""Test OneOfTypes with no types raises error."""
|
319
|
+
with self.assertRaises(TypeError):
|
320
|
+
brainstate.mixin.OneOfTypes()
|
321
|
+
|
322
|
+
def test_one_of_types_with_none(self):
|
323
|
+
"""Test OneOfTypes with None for optional types."""
|
324
|
+
MaybeInt = brainstate.mixin.OneOfTypes(int, type(None))
|
325
|
+
|
326
|
+
self.assertTrue(isinstance(42, MaybeInt))
|
327
|
+
self.assertTrue(isinstance(None, MaybeInt))
|
328
|
+
self.assertFalse(isinstance("hello", MaybeInt))
|
329
|
+
|
330
|
+
|
331
|
+
|
332
|
+
class TestNotImplemented(unittest.TestCase):
|
333
|
+
"""Test not_implemented decorator."""
|
334
|
+
|
335
|
+
def test_not_implemented_decorator(self):
|
336
|
+
"""Test not_implemented decorator marks functions."""
|
337
|
+
|
338
|
+
@brainstate.mixin.not_implemented
|
339
|
+
def my_function():
|
340
|
+
pass
|
341
|
+
|
342
|
+
self.assertTrue(hasattr(my_function, 'not_implemented'))
|
343
|
+
self.assertTrue(my_function.not_implemented)
|
344
|
+
|
345
|
+
def test_not_implemented_raises(self):
|
346
|
+
"""Test not_implemented decorator raises error when called."""
|
347
|
+
|
348
|
+
@brainstate.mixin.not_implemented
|
349
|
+
def my_function():
|
350
|
+
pass
|
351
|
+
|
352
|
+
with self.assertRaises(NotImplementedError) as cm:
|
353
|
+
my_function()
|
354
|
+
|
355
|
+
self.assertIn("my_function", str(cm.exception))
|
356
|
+
|
357
|
+
|
358
|
+
class TestMode(unittest.TestCase):
|
359
|
+
"""Test Mode base class."""
|
360
|
+
|
361
|
+
def test_mode_creation(self):
|
362
|
+
"""Test basic Mode creation."""
|
363
|
+
mode = brainstate.mixin.Mode()
|
364
|
+
self.assertIsNotNone(mode)
|
365
|
+
|
366
|
+
def test_mode_repr(self):
|
367
|
+
"""Test Mode string representation."""
|
368
|
+
mode = brainstate.mixin.Mode()
|
369
|
+
self.assertEqual(repr(mode), "Mode")
|
370
|
+
|
371
|
+
def test_mode_equality(self):
|
372
|
+
"""Test Mode equality comparison."""
|
373
|
+
mode1 = brainstate.mixin.Mode()
|
374
|
+
mode2 = brainstate.mixin.Mode()
|
375
|
+
self.assertEqual(mode1, mode2)
|
376
|
+
|
377
|
+
def test_mode_is_a(self):
|
378
|
+
"""Test Mode.is_a() method."""
|
379
|
+
mode = brainstate.mixin.Mode()
|
380
|
+
self.assertTrue(mode.is_a(brainstate.mixin.Mode))
|
381
|
+
self.assertFalse(mode.is_a(brainstate.mixin.Training))
|
382
|
+
|
383
|
+
def test_mode_has(self):
|
384
|
+
"""Test Mode.has() method."""
|
385
|
+
mode = brainstate.mixin.Mode()
|
386
|
+
self.assertTrue(mode.has(brainstate.mixin.Mode))
|
387
|
+
self.assertFalse(mode.has(brainstate.mixin.Training))
|
388
|
+
|
389
|
+
def test_custom_mode(self):
|
390
|
+
"""Test creating custom mode."""
|
391
|
+
|
392
|
+
class CustomMode(brainstate.mixin.Mode):
|
393
|
+
def __init__(self, value):
|
394
|
+
self.value = value
|
395
|
+
|
396
|
+
mode = CustomMode(42)
|
397
|
+
self.assertEqual(mode.value, 42)
|
398
|
+
self.assertTrue(mode.has(brainstate.mixin.Mode))
|
399
|
+
|
400
|
+
|
401
|
+
class TestTraining(unittest.TestCase):
|
402
|
+
"""Test Training mode."""
|
403
|
+
|
404
|
+
def test_training_creation(self):
|
405
|
+
"""Test Training mode creation."""
|
406
|
+
training = brainstate.mixin.Training()
|
407
|
+
self.assertIsNotNone(training)
|
408
|
+
|
409
|
+
def test_training_is_mode(self):
|
410
|
+
"""Test Training is a Mode."""
|
411
|
+
training = brainstate.mixin.Training()
|
412
|
+
self.assertTrue(training.has(brainstate.mixin.Mode))
|
413
|
+
|
414
|
+
def test_training_is_a(self):
|
415
|
+
"""Test Training.is_a() method."""
|
416
|
+
training = brainstate.mixin.Training()
|
417
|
+
self.assertTrue(training.is_a(brainstate.mixin.Training))
|
418
|
+
self.assertFalse(training.is_a(brainstate.mixin.Batching))
|
419
|
+
|
420
|
+
def test_training_has(self):
|
421
|
+
"""Test Training.has() method."""
|
422
|
+
training = brainstate.mixin.Training()
|
423
|
+
self.assertTrue(training.has(brainstate.mixin.Training))
|
424
|
+
self.assertFalse(training.has(brainstate.mixin.Batching))
|
425
|
+
|
426
|
+
def test_training_joint_types(self):
|
427
|
+
"""Test Training with JointTypes."""
|
428
|
+
training = brainstate.mixin.Training()
|
429
|
+
self.assertTrue(training.is_a(brainstate.mixin.JointTypes(brainstate.mixin.Training)))
|
430
|
+
self.assertTrue(training.has(brainstate.mixin.JointTypes(brainstate.mixin.Training)))
|
431
|
+
|
432
|
+
|
433
|
+
class TestBatching(unittest.TestCase):
|
434
|
+
"""Test Batching mode."""
|
435
|
+
|
436
|
+
def test_batching_creation(self):
|
437
|
+
"""Test Batching mode creation."""
|
438
|
+
batching = brainstate.mixin.Batching()
|
439
|
+
self.assertIsNotNone(batching)
|
440
|
+
|
441
|
+
def test_batching_default_params(self):
|
442
|
+
"""Test Batching default parameters."""
|
443
|
+
batching = brainstate.mixin.Batching()
|
444
|
+
self.assertEqual(batching.batch_size, 1)
|
445
|
+
self.assertEqual(batching.batch_axis, 0)
|
446
|
+
|
447
|
+
def test_batching_custom_params(self):
|
448
|
+
"""Test Batching with custom parameters."""
|
449
|
+
batching = brainstate.mixin.Batching(batch_size=32, batch_axis=1)
|
450
|
+
self.assertEqual(batching.batch_size, 32)
|
451
|
+
self.assertEqual(batching.batch_axis, 1)
|
452
|
+
|
453
|
+
def test_batching_repr(self):
|
454
|
+
"""Test Batching string representation."""
|
455
|
+
batching = brainstate.mixin.Batching(batch_size=64, batch_axis=0)
|
456
|
+
self.assertIn("64", repr(batching))
|
457
|
+
self.assertIn("0", repr(batching))
|
458
|
+
|
459
|
+
def test_batching_is_mode(self):
|
460
|
+
"""Test Batching is a Mode."""
|
461
|
+
batching = brainstate.mixin.Batching()
|
462
|
+
self.assertTrue(batching.has(brainstate.mixin.Mode))
|
463
|
+
|
464
|
+
def test_batching_is_a(self):
|
465
|
+
"""Test Batching.is_a() method."""
|
466
|
+
batching = brainstate.mixin.Batching()
|
467
|
+
self.assertTrue(batching.is_a(brainstate.mixin.Batching))
|
468
|
+
self.assertFalse(batching.is_a(brainstate.mixin.Training))
|
469
|
+
|
470
|
+
def test_batching_has(self):
|
471
|
+
"""Test Batching.has() method."""
|
472
|
+
batching = brainstate.mixin.Batching()
|
473
|
+
self.assertTrue(batching.has(brainstate.mixin.Batching))
|
474
|
+
self.assertFalse(batching.has(brainstate.mixin.Training))
|
475
|
+
|
476
|
+
|
477
|
+
class TestJointMode(unittest.TestCase):
|
478
|
+
"""Test JointMode functionality."""
|
479
|
+
|
480
|
+
def test_joint_mode_creation(self):
|
481
|
+
"""Test JointMode creation."""
|
482
|
+
training = brainstate.mixin.Training()
|
483
|
+
batching = brainstate.mixin.Batching()
|
484
|
+
joint = brainstate.mixin.JointMode(training, batching)
|
485
|
+
self.assertIsNotNone(joint)
|
486
|
+
|
487
|
+
def test_joint_mode_repr(self):
|
488
|
+
"""Test JointMode string representation."""
|
489
|
+
training = brainstate.mixin.Training()
|
490
|
+
batching = brainstate.mixin.Batching(batch_size=32)
|
491
|
+
joint = brainstate.mixin.JointMode(training, batching)
|
492
|
+
|
493
|
+
repr_str = repr(joint)
|
494
|
+
self.assertIn("JointMode", repr_str)
|
495
|
+
self.assertIn("Training", repr_str)
|
496
|
+
self.assertIn("Batching", repr_str)
|
497
|
+
|
498
|
+
def test_joint_mode_has(self):
|
499
|
+
"""Test JointMode.has() method."""
|
500
|
+
training = brainstate.mixin.Training()
|
501
|
+
batching = brainstate.mixin.Batching()
|
502
|
+
joint = brainstate.mixin.JointMode(training, batching)
|
503
|
+
|
504
|
+
self.assertTrue(joint.has(brainstate.mixin.Training))
|
505
|
+
self.assertTrue(joint.has(brainstate.mixin.Batching))
|
506
|
+
self.assertTrue(joint.has(brainstate.mixin.Mode))
|
507
|
+
|
508
|
+
def test_joint_mode_is_a(self):
|
509
|
+
"""Test JointMode.is_a() method."""
|
510
|
+
training = brainstate.mixin.Training()
|
511
|
+
batching = brainstate.mixin.Batching()
|
512
|
+
joint = brainstate.mixin.JointMode(training, batching)
|
513
|
+
|
514
|
+
# JointMode.is_a() works by checking if the JointTypes of the mode types
|
515
|
+
# matches the expected type. This is a complex comparison.
|
516
|
+
# For practical use, test that it correctly identifies single types
|
517
|
+
self.assertFalse(joint.is_a(brainstate.mixin.Training)) # Not just Training
|
518
|
+
self.assertFalse(joint.is_a(brainstate.mixin.Batching)) # Not just Batching
|
519
|
+
|
520
|
+
# But a single mode joint should match
|
521
|
+
single_joint = brainstate.mixin.JointMode(training)
|
522
|
+
self.assertTrue(single_joint.is_a(brainstate.mixin.Training))
|
523
|
+
|
524
|
+
def test_joint_mode_single_mode(self):
|
525
|
+
"""Test JointMode with single mode."""
|
526
|
+
batching = brainstate.mixin.Batching()
|
527
|
+
joint = brainstate.mixin.JointMode(batching)
|
528
|
+
|
529
|
+
self.assertTrue(joint.has(brainstate.mixin.Batching))
|
530
|
+
self.assertTrue(joint.is_a(brainstate.mixin.Batching))
|
531
|
+
|
532
|
+
def test_joint_mode_attribute_access(self):
|
533
|
+
"""Test JointMode attribute delegation."""
|
534
|
+
batching = brainstate.mixin.Batching(batch_size=32, batch_axis=1)
|
535
|
+
training = brainstate.mixin.Training()
|
536
|
+
joint = brainstate.mixin.JointMode(batching, training)
|
537
|
+
|
538
|
+
# Should access batching attributes
|
539
|
+
self.assertEqual(joint.batch_size, 32)
|
540
|
+
self.assertEqual(joint.batch_axis, 1)
|
541
|
+
|
542
|
+
def test_joint_mode_invalid_type(self):
|
543
|
+
"""Test JointMode with non-Mode raises error."""
|
544
|
+
with self.assertRaises(TypeError):
|
545
|
+
brainstate.mixin.JointMode("not a mode")
|
546
|
+
|
547
|
+
def test_joint_mode_modes_attribute(self):
|
548
|
+
"""Test accessing modes attribute."""
|
549
|
+
training = brainstate.mixin.Training()
|
550
|
+
batching = brainstate.mixin.Batching()
|
551
|
+
joint = brainstate.mixin.JointMode(training, batching)
|
552
|
+
|
553
|
+
self.assertEqual(len(joint.modes), 2)
|
554
|
+
self.assertIn(training, joint.modes)
|
555
|
+
self.assertIn(batching, joint.modes)
|
556
|
+
|
557
|
+
def test_joint_mode_types_attribute(self):
|
558
|
+
"""Test accessing types attribute."""
|
559
|
+
training = brainstate.mixin.Training()
|
560
|
+
batching = brainstate.mixin.Batching()
|
561
|
+
joint = brainstate.mixin.JointMode(training, batching)
|
562
|
+
|
563
|
+
self.assertEqual(len(joint.types), 2)
|
564
|
+
self.assertIn(brainstate.mixin.Training, joint.types)
|
565
|
+
self.assertIn(brainstate.mixin.Batching, joint.types)
|
566
|
+
|
567
|
+
|
568
|
+
class TestIntegration(unittest.TestCase):
    """Integration tests combining ParamDesc, JointTypes and the Mode system."""

    def test_param_desc_with_modes(self):
        """A ParamDesc descriptor forwards a Mode instance to the constructor."""

        class Model(brainstate.mixin.ParamDesc):
            def __init__(self, size, mode=None):
                self.size = size
                self.mode = mode if mode is not None else brainstate.mixin.Mode()

        # Create a descriptor bound to training mode, then instantiate it.
        train_model_desc = Model.desc(size=100, mode=brainstate.mixin.Training())
        model = train_model_desc()

        self.assertEqual(model.size, 100)
        self.assertTrue(model.mode.has(brainstate.mixin.Training))

    def test_joint_types_with_multiple_mixins(self):
        """An object inheriting every mixin is an instance of their JointTypes."""

        class Serializable(brainstate.mixin.Mixin):
            def save(self):
                return "saved"

        class Trainable(brainstate.mixin.Mixin):
            def train(self):
                return "trained"

        class Evaluable(brainstate.mixin.Mixin):
            def evaluate(self):
                return "evaluated"

        FullModel = brainstate.mixin.JointTypes(Serializable, Trainable, Evaluable)

        class MyModel(Serializable, Trainable, Evaluable):
            pass

        model = MyModel()
        self.assertIsInstance(model, FullModel)
        self.assertEqual(model.save(), "saved")
        self.assertEqual(model.train(), "trained")
        self.assertEqual(model.evaluate(), "evaluated")

    def test_complex_mode_scenario(self):
        """A toy network switches behaviour based on the modes it is given."""

        class NeuralNetwork:
            def __init__(self):
                self.mode = None

            def set_mode(self, mode):
                self.mode = mode

            def forward(self, x):
                if self.mode is None:
                    return x

                if self.mode.has(brainstate.mixin.Training):
                    # Add noise during training
                    x = x + 0.1

                if self.mode.has(brainstate.mixin.Batching):
                    # Process in batches
                    batch_size = self.mode.batch_size
                    # Just return with batch info for testing
                    return x, batch_size

                return x

        net = NeuralNetwork()

        # No mode set: forward() is a pass-through.
        result = net.forward(1.0)
        self.assertEqual(result, 1.0)

        # Training mode only: noise is added, no batch info returned.
        net.set_mode(brainstate.mixin.Training())
        result = net.forward(1.0)
        self.assertAlmostEqual(result, 1.1)

        # Joint training + batching mode: noise added and batch size exposed.
        training = brainstate.mixin.Training()
        batching = brainstate.mixin.Batching(batch_size=32)
        net.set_mode(brainstate.mixin.JointMode(training, batching))

        result, batch_size = net.forward(1.0)
        self.assertAlmostEqual(result, 1.1)
        self.assertEqual(batch_size, 32)
|
657
|
+
|
658
|
+
|
659
|
+
class TestJointTypesComprehensive(unittest.TestCase):
    """Comprehensive tests for JointTypes special methods and functionality."""

    def setUp(self):
        """Create three unrelated local classes used as JointTypes members."""

        class A:
            pass

        class B:
            pass

        class C:
            pass

        self.A = A
        self.B = B
        self.C = C

    def test_repr(self):
        """Test __repr__ method."""
        JT = brainstate.mixin.JointTypes[self.A, self.B]
        repr_str = repr(JT)
        self.assertIn('JointTypes', repr_str)
        self.assertIn('A', repr_str)
        self.assertIn('B', repr_str)

    def test_eq_same_order(self):
        """Test equality with same type order."""
        JT1 = brainstate.mixin.JointTypes[self.A, self.B]
        JT2 = brainstate.mixin.JointTypes[self.A, self.B]
        self.assertEqual(JT1, JT2)

    def test_eq_different_order(self):
        """Test equality with different type order."""
        JT1 = brainstate.mixin.JointTypes[self.A, self.B]
        JT2 = brainstate.mixin.JointTypes[self.B, self.A]
        self.assertEqual(JT1, JT2)

    def test_eq_different_types(self):
        """Test inequality with different types."""
        JT1 = brainstate.mixin.JointTypes[self.A, self.B]
        JT2 = brainstate.mixin.JointTypes[self.A, self.C]
        self.assertNotEqual(JT1, JT2)

    def test_eq_with_non_jointtypes(self):
        """Test equality with non-JointTypes object."""
        JT = brainstate.mixin.JointTypes[self.A, self.B]
        self.assertNotEqual(JT, "not a type")
        self.assertNotEqual(JT, 42)
        self.assertNotEqual(JT, self.A)

    def test_hash_consistency(self):
        """Test hash consistency."""
        JT = brainstate.mixin.JointTypes[self.A, self.B]
        hash1 = hash(JT)
        hash2 = hash(JT)
        self.assertEqual(hash1, hash2)

    def test_hash_order_independent(self):
        """Test hash is order-independent."""
        JT1 = brainstate.mixin.JointTypes[self.A, self.B]
        JT2 = brainstate.mixin.JointTypes[self.B, self.A]
        self.assertEqual(hash(JT1), hash(JT2))

    def test_hash_different_for_different_types(self):
        """Test different types have different hashes."""
        JT1 = brainstate.mixin.JointTypes[self.A, self.B]
        JT2 = brainstate.mixin.JointTypes[self.A, self.C]
        # Note: hash collision is possible but unlikely for different types
        self.assertNotEqual(hash(JT1), hash(JT2))

    def test_hashable_in_set(self):
        """Test JointTypes can be used in sets."""
        JT1 = brainstate.mixin.JointTypes[self.A, self.B]
        JT2 = brainstate.mixin.JointTypes[self.B, self.A]
        JT3 = brainstate.mixin.JointTypes[self.A, self.C]

        type_set = {JT1, JT2, JT3}
        # JT1 and JT2 are equal, so set should have 2 elements
        self.assertEqual(len(type_set), 2)
        self.assertIn(JT1, type_set)
        self.assertIn(JT2, type_set)
        self.assertIn(JT3, type_set)

    def test_as_dict_key(self):
        """Test JointTypes can be used as dict keys."""
        JT1 = brainstate.mixin.JointTypes[self.A, self.B]
        JT2 = brainstate.mixin.JointTypes[self.B, self.A]

        type_dict = {JT1: "AB type"}
        self.assertIn(JT1, type_dict)
        # JT2 should work as key since it's equal to JT1
        self.assertIn(JT2, type_dict)
        self.assertEqual(type_dict[JT2], "AB type")

    def test_pickle_roundtrip(self):
        """Test pickling and unpickling with built-in types."""
        import pickle
        # Use built-in types since local classes can't be pickled
        JT = brainstate.mixin.JointTypes[int, str]
        pickled = pickle.dumps(JT)
        unpickled = pickle.loads(pickled)
        self.assertEqual(JT, unpickled)
        self.assertEqual(hash(JT), hash(unpickled))

    def test_pickle_preserves_isinstance(self):
        """Test isinstance works after pickle with built-in types."""
        import pickle

        # Use built-in types for pickling; local classes are not picklable.
        JT = brainstate.mixin.JointTypes[int, object]
        pickled = pickle.dumps(JT)
        unpickled = pickle.loads(pickled)

        obj = 42
        self.assertIsInstance(obj, JT)
        self.assertIsInstance(obj, unpickled)

    def test_multiple_types(self):
        """Test JointTypes with more than 2 types."""
        JT = brainstate.mixin.JointTypes[self.A, self.B, self.C]

        class ABC(self.A, self.B, self.C):
            pass

        self.assertTrue(issubclass(ABC, JT))

        class AB(self.A, self.B):
            pass

        # Missing C, so AB must not satisfy the joint requirement.
        self.assertFalse(issubclass(AB, JT))

    def test_subscript_vs_call_syntax(self):
        """Test subscript and call syntax produce equal results."""
        JT_subscript = brainstate.mixin.JointTypes[self.A, self.B]
        JT_call = brainstate.mixin.JointTypes(self.A, self.B)
        self.assertEqual(JT_subscript, JT_call)
        self.assertEqual(hash(JT_subscript), hash(JT_call))

    def test_args_attribute(self):
        """Test __args__ attribute contains correct types."""
        JT = brainstate.mixin.JointTypes[self.A, self.B]
        self.assertIn(self.A, JT.__args__)
        self.assertIn(self.B, JT.__args__)
        self.assertEqual(len(JT.__args__), 2)
|
808
|
+
|
809
|
+
|
810
|
+
class TestOneOfTypesComprehensive(unittest.TestCase):
    """Comprehensive tests for OneOfTypes special methods and functionality."""

    def setUp(self):
        """Create three unrelated local classes used as OneOfTypes members."""

        class A:
            pass

        class B:
            pass

        class C:
            pass

        self.A = A
        self.B = B
        self.C = C

    def test_repr(self):
        """Test __repr__ method."""
        OT = brainstate.mixin.OneOfTypes[self.A, self.B]
        repr_str = repr(OT)
        self.assertIn('OneOfTypes', repr_str)
        self.assertIn('A', repr_str)
        self.assertIn('B', repr_str)

    def test_eq_same_order(self):
        """Test equality with same type order."""
        OT1 = brainstate.mixin.OneOfTypes[self.A, self.B]
        OT2 = brainstate.mixin.OneOfTypes[self.A, self.B]
        self.assertEqual(OT1, OT2)

    def test_eq_different_order(self):
        """Test equality with different type order."""
        OT1 = brainstate.mixin.OneOfTypes[self.A, self.B]
        OT2 = brainstate.mixin.OneOfTypes[self.B, self.A]
        self.assertEqual(OT1, OT2)

    def test_eq_different_types(self):
        """Test inequality with different types."""
        OT1 = brainstate.mixin.OneOfTypes[self.A, self.B]
        OT2 = brainstate.mixin.OneOfTypes[self.A, self.C]
        self.assertNotEqual(OT1, OT2)

    def test_eq_with_non_oneoftypes(self):
        """Test equality with non-OneOfTypes object."""
        OT = brainstate.mixin.OneOfTypes[self.A, self.B]
        self.assertNotEqual(OT, "not a type")
        self.assertNotEqual(OT, 42)
        self.assertNotEqual(OT, self.A)

    def test_hash_consistency(self):
        """Test hash consistency."""
        OT = brainstate.mixin.OneOfTypes[self.A, self.B]
        hash1 = hash(OT)
        hash2 = hash(OT)
        self.assertEqual(hash1, hash2)

    def test_hash_order_independent(self):
        """Test hash is order-independent."""
        OT1 = brainstate.mixin.OneOfTypes[self.A, self.B]
        OT2 = brainstate.mixin.OneOfTypes[self.B, self.A]
        self.assertEqual(hash(OT1), hash(OT2))

    def test_hash_different_for_different_types(self):
        """Test different types have different hashes."""
        OT1 = brainstate.mixin.OneOfTypes[self.A, self.B]
        OT2 = brainstate.mixin.OneOfTypes[self.A, self.C]
        # Note: hash collision is possible but unlikely for different types
        self.assertNotEqual(hash(OT1), hash(OT2))

    def test_hashable_in_set(self):
        """Test OneOfTypes can be used in sets."""
        OT1 = brainstate.mixin.OneOfTypes[self.A, self.B]
        OT2 = brainstate.mixin.OneOfTypes[self.B, self.A]
        OT3 = brainstate.mixin.OneOfTypes[self.A, self.C]

        type_set = {OT1, OT2, OT3}
        # OT1 and OT2 are equal, so set should have 2 elements
        self.assertEqual(len(type_set), 2)
        self.assertIn(OT1, type_set)
        self.assertIn(OT2, type_set)
        self.assertIn(OT3, type_set)

    def test_as_dict_key(self):
        """Test OneOfTypes can be used as dict keys."""
        OT1 = brainstate.mixin.OneOfTypes[self.A, self.B]
        OT2 = brainstate.mixin.OneOfTypes[self.B, self.A]

        type_dict = {OT1: "A or B type"}
        self.assertIn(OT1, type_dict)
        # OT2 should work as key since it's equal to OT1
        self.assertIn(OT2, type_dict)
        self.assertEqual(type_dict[OT2], "A or B type")

    def test_pickle_roundtrip(self):
        """Test pickling and unpickling with built-in types."""
        import pickle
        # Use built-in types since local classes can't be pickled
        OT = brainstate.mixin.OneOfTypes[int, str]
        pickled = pickle.dumps(OT)
        unpickled = pickle.loads(pickled)
        self.assertEqual(OT, unpickled)
        self.assertEqual(hash(OT), hash(unpickled))

    def test_pickle_preserves_isinstance(self):
        """Test isinstance works after pickle with built-in types."""
        import pickle
        # Use built-in types for pickling
        OT = brainstate.mixin.OneOfTypes[int, str]
        pickled = pickle.dumps(OT)
        unpickled = pickle.loads(pickled)

        obj_a = 42
        obj_b = "hello"

        self.assertIsInstance(obj_a, OT)
        self.assertIsInstance(obj_a, unpickled)
        self.assertIsInstance(obj_b, OT)
        self.assertIsInstance(obj_b, unpickled)

    def test_isinstance_with_any_type(self):
        """Test isinstance returns True if object is instance of any type."""
        OT = brainstate.mixin.OneOfTypes[self.A, self.B]

        obj_a = self.A()
        obj_b = self.B()
        obj_c = self.C()

        self.assertIsInstance(obj_a, OT)
        self.assertIsInstance(obj_b, OT)
        self.assertNotIsInstance(obj_c, OT)

    def test_issubclass_with_any_type(self):
        """Test issubclass returns True if class is subclass of any type."""
        OT = brainstate.mixin.OneOfTypes[self.A, self.B]

        class SubA(self.A):
            pass

        class SubB(self.B):
            pass

        self.assertTrue(issubclass(SubA, OT))
        self.assertTrue(issubclass(SubB, OT))
        self.assertTrue(issubclass(self.A, OT))
        self.assertTrue(issubclass(self.B, OT))
        self.assertFalse(issubclass(self.C, OT))

    def test_multiple_types(self):
        """Test OneOfTypes with more than 2 types."""
        OT = brainstate.mixin.OneOfTypes[self.A, self.B, self.C]

        obj_a = self.A()
        obj_b = self.B()
        obj_c = self.C()

        self.assertIsInstance(obj_a, OT)
        self.assertIsInstance(obj_b, OT)
        self.assertIsInstance(obj_c, OT)

    def test_subscript_vs_call_syntax(self):
        """Test subscript and call syntax produce equal results."""
        OT_subscript = brainstate.mixin.OneOfTypes[self.A, self.B]
        OT_call = brainstate.mixin.OneOfTypes(self.A, self.B)
        self.assertEqual(OT_subscript, OT_call)
        self.assertEqual(hash(OT_subscript), hash(OT_call))

    def test_args_attribute(self):
        """Test __args__ attribute contains correct types."""
        OT = brainstate.mixin.OneOfTypes[self.A, self.B]
        self.assertIn(self.A, OT.__args__)
        self.assertIn(self.B, OT.__args__)
        self.assertEqual(len(OT.__args__), 2)

    def test_with_builtin_types(self):
        """Test OneOfTypes with built-in types."""
        OT = brainstate.mixin.OneOfTypes[int, float, str]

        self.assertIsInstance(42, OT)
        self.assertIsInstance(3.14, OT)
        self.assertIsInstance("hello", OT)
        self.assertNotIsInstance([], OT)
|
991
|
+
|
992
|
+
|
993
|
+
class TestJointTy(unittest.TestCase):
    """Call and subscript syntax must build equal JointTypes / OneOfTypes.

    Previously this class did not subclass ``unittest.TestCase``, so
    ``unittest.main()`` never ran it. It also only printed its comparisons and
    asserted nothing.
    """

    def test1(self):
        class Potassium:
            pass

        class Calcium:
            pass

        # JointTypes: function-call and subscript forms must agree.
        result1 = brainstate.mixin.JointTypes(Potassium, Calcium)
        result2 = brainstate.mixin.JointTypes[Potassium, Calcium]
        self.assertEqual(result1, result2)

        # OneOfTypes: same requirement.
        result3 = brainstate.mixin.OneOfTypes(Potassium, Calcium)
        result4 = brainstate.mixin.OneOfTypes[Potassium, Calcium]
        self.assertEqual(result3, result4)
|
1014
|
+
|
1015
|
+
|
1016
|
+
if __name__ == '__main__':
    # Run all unittest.TestCase subclasses in this module when executed as a script.
    unittest.main()
|