brainstate-0.1.10-py2.py3-none-any.whl → brainstate-0.2.1-py2.py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
Files changed (163)
  1. brainstate/__init__.py +169 -58
  2. brainstate/_compatible_import.py +340 -148
  3. brainstate/_compatible_import_test.py +681 -0
  4. brainstate/_deprecation.py +210 -0
  5. brainstate/_deprecation_test.py +2319 -0
  6. brainstate/{util/error.py → _error.py} +45 -55
  7. brainstate/_state.py +1652 -1605
  8. brainstate/_state_test.py +52 -52
  9. brainstate/_utils.py +47 -47
  10. brainstate/environ.py +1495 -563
  11. brainstate/environ_test.py +1223 -62
  12. brainstate/graph/__init__.py +22 -29
  13. brainstate/graph/_node.py +240 -0
  14. brainstate/graph/_node_test.py +589 -0
  15. brainstate/graph/{_graph_operation.py → _operation.py} +1624 -1738
  16. brainstate/graph/_operation_test.py +1147 -0
  17. brainstate/mixin.py +1433 -365
  18. brainstate/mixin_test.py +1017 -77
  19. brainstate/nn/__init__.py +137 -135
  20. brainstate/nn/_activations.py +1100 -808
  21. brainstate/nn/_activations_test.py +354 -331
  22. brainstate/nn/_collective_ops.py +633 -514
  23. brainstate/nn/_collective_ops_test.py +774 -43
  24. brainstate/nn/_common.py +226 -178
  25. brainstate/nn/_common_test.py +154 -0
  26. brainstate/nn/_conv.py +2010 -501
  27. brainstate/nn/_conv_test.py +849 -238
  28. brainstate/nn/_delay.py +575 -588
  29. brainstate/nn/_delay_test.py +243 -238
  30. brainstate/nn/_dropout.py +618 -426
  31. brainstate/nn/_dropout_test.py +477 -100
  32. brainstate/nn/_dynamics.py +1267 -1343
  33. brainstate/nn/_dynamics_test.py +67 -78
  34. brainstate/nn/_elementwise.py +1298 -1119
  35. brainstate/nn/_elementwise_test.py +830 -169
  36. brainstate/nn/_embedding.py +408 -58
  37. brainstate/nn/_embedding_test.py +156 -0
  38. brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +233 -239
  39. brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +115 -114
  40. brainstate/nn/{_linear_mv.py → _event_linear.py} +83 -83
  41. brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +121 -120
  42. brainstate/nn/_exp_euler.py +254 -92
  43. brainstate/nn/_exp_euler_test.py +377 -35
  44. brainstate/nn/_linear.py +744 -424
  45. brainstate/nn/_linear_test.py +475 -107
  46. brainstate/nn/_metrics.py +1070 -0
  47. brainstate/nn/_metrics_test.py +611 -0
  48. brainstate/nn/_module.py +384 -377
  49. brainstate/nn/_module_test.py +40 -40
  50. brainstate/nn/_normalizations.py +1334 -975
  51. brainstate/nn/_normalizations_test.py +699 -73
  52. brainstate/nn/_paddings.py +1020 -0
  53. brainstate/nn/_paddings_test.py +723 -0
  54. brainstate/nn/_poolings.py +2239 -1177
  55. brainstate/nn/_poolings_test.py +953 -217
  56. brainstate/nn/{_rate_rnns.py → _rnns.py} +946 -554
  57. brainstate/nn/_rnns_test.py +593 -0
  58. brainstate/nn/_utils.py +216 -89
  59. brainstate/nn/_utils_test.py +402 -0
  60. brainstate/{init/_random_inits.py → nn/init.py} +809 -553
  61. brainstate/{init/_random_inits_test.py → nn/init_test.py} +180 -149
  62. brainstate/random/__init__.py +270 -24
  63. brainstate/random/_rand_funs.py +3938 -3616
  64. brainstate/random/_rand_funs_test.py +640 -567
  65. brainstate/random/_rand_seed.py +675 -210
  66. brainstate/random/_rand_seed_test.py +48 -48
  67. brainstate/random/_rand_state.py +1617 -1409
  68. brainstate/random/_rand_state_test.py +551 -0
  69. brainstate/transform/__init__.py +59 -0
  70. brainstate/transform/_ad_checkpoint.py +176 -0
  71. brainstate/{compile → transform}/_ad_checkpoint_test.py +49 -49
  72. brainstate/{augment → transform}/_autograd.py +1025 -778
  73. brainstate/{augment → transform}/_autograd_test.py +1289 -1289
  74. brainstate/transform/_conditions.py +316 -0
  75. brainstate/{compile → transform}/_conditions_test.py +220 -220
  76. brainstate/{compile → transform}/_error_if.py +94 -92
  77. brainstate/{compile → transform}/_error_if_test.py +52 -52
  78. brainstate/transform/_eval_shape.py +145 -0
  79. brainstate/{augment → transform}/_eval_shape_test.py +38 -38
  80. brainstate/{compile → transform}/_jit.py +399 -346
  81. brainstate/{compile → transform}/_jit_test.py +143 -143
  82. brainstate/{compile → transform}/_loop_collect_return.py +675 -536
  83. brainstate/{compile → transform}/_loop_collect_return_test.py +58 -58
  84. brainstate/{compile → transform}/_loop_no_collection.py +283 -184
  85. brainstate/{compile → transform}/_loop_no_collection_test.py +50 -50
  86. brainstate/transform/_make_jaxpr.py +2016 -0
  87. brainstate/transform/_make_jaxpr_test.py +1510 -0
  88. brainstate/transform/_mapping.py +529 -0
  89. brainstate/transform/_mapping_test.py +194 -0
  90. brainstate/{compile → transform}/_progress_bar.py +255 -202
  91. brainstate/{augment → transform}/_random.py +171 -151
  92. brainstate/{compile → transform}/_unvmap.py +256 -159
  93. brainstate/transform/_util.py +286 -0
  94. brainstate/typing.py +837 -304
  95. brainstate/typing_test.py +780 -0
  96. brainstate/util/__init__.py +27 -50
  97. brainstate/util/_others.py +1025 -0
  98. brainstate/util/_others_test.py +962 -0
  99. brainstate/util/_pretty_pytree.py +1301 -0
  100. brainstate/util/_pretty_pytree_test.py +675 -0
  101. brainstate/util/{pretty_repr.py → _pretty_repr.py} +462 -328
  102. brainstate/util/_pretty_repr_test.py +696 -0
  103. brainstate/util/filter.py +945 -469
  104. brainstate/util/filter_test.py +912 -0
  105. brainstate/util/struct.py +910 -523
  106. brainstate/util/struct_test.py +602 -0
  107. {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/METADATA +108 -91
  108. brainstate-0.2.1.dist-info/RECORD +111 -0
  109. {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/licenses/LICENSE +202 -202
  110. brainstate/augment/__init__.py +0 -30
  111. brainstate/augment/_eval_shape.py +0 -99
  112. brainstate/augment/_mapping.py +0 -1060
  113. brainstate/augment/_mapping_test.py +0 -597
  114. brainstate/compile/__init__.py +0 -38
  115. brainstate/compile/_ad_checkpoint.py +0 -204
  116. brainstate/compile/_conditions.py +0 -256
  117. brainstate/compile/_make_jaxpr.py +0 -888
  118. brainstate/compile/_make_jaxpr_test.py +0 -156
  119. brainstate/compile/_util.py +0 -147
  120. brainstate/functional/__init__.py +0 -27
  121. brainstate/graph/_graph_node.py +0 -244
  122. brainstate/graph/_graph_node_test.py +0 -73
  123. brainstate/graph/_graph_operation_test.py +0 -563
  124. brainstate/init/__init__.py +0 -26
  125. brainstate/init/_base.py +0 -52
  126. brainstate/init/_generic.py +0 -244
  127. brainstate/init/_regular_inits.py +0 -105
  128. brainstate/init/_regular_inits_test.py +0 -50
  129. brainstate/nn/_inputs.py +0 -608
  130. brainstate/nn/_ltp.py +0 -28
  131. brainstate/nn/_neuron.py +0 -705
  132. brainstate/nn/_neuron_test.py +0 -161
  133. brainstate/nn/_others.py +0 -46
  134. brainstate/nn/_projection.py +0 -486
  135. brainstate/nn/_rate_rnns_test.py +0 -63
  136. brainstate/nn/_readout.py +0 -209
  137. brainstate/nn/_readout_test.py +0 -53
  138. brainstate/nn/_stp.py +0 -236
  139. brainstate/nn/_synapse.py +0 -505
  140. brainstate/nn/_synapse_test.py +0 -131
  141. brainstate/nn/_synaptic_projection.py +0 -423
  142. brainstate/nn/_synouts.py +0 -162
  143. brainstate/nn/_synouts_test.py +0 -57
  144. brainstate/nn/metrics.py +0 -388
  145. brainstate/optim/__init__.py +0 -38
  146. brainstate/optim/_base.py +0 -64
  147. brainstate/optim/_lr_scheduler.py +0 -448
  148. brainstate/optim/_lr_scheduler_test.py +0 -50
  149. brainstate/optim/_optax_optimizer.py +0 -152
  150. brainstate/optim/_optax_optimizer_test.py +0 -53
  151. brainstate/optim/_sgd_optimizer.py +0 -1104
  152. brainstate/random/_random_for_unit.py +0 -52
  153. brainstate/surrogate.py +0 -1957
  154. brainstate/transform.py +0 -23
  155. brainstate/util/caller.py +0 -98
  156. brainstate/util/others.py +0 -540
  157. brainstate/util/pretty_pytree.py +0 -945
  158. brainstate/util/pretty_pytree_test.py +0 -159
  159. brainstate/util/pretty_table.py +0 -2954
  160. brainstate/util/scaling.py +0 -258
  161. brainstate-0.1.10.dist-info/RECORD +0 -130
  162. {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/WHEEL +0 -0
  163. {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/top_level.txt +0 -0
brainstate/mixin_test.py CHANGED
@@ -1,77 +1,1017 @@
1
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
- import unittest
17
- import jax
18
- import jax.numpy as jnp
19
- import numpy as np
20
-
21
- import brainstate
22
-
23
-
24
- class TestMixin(unittest.TestCase):
25
- def test_mixin(self):
26
- self.assertTrue(brainstate.mixin.Mixin)
27
- self.assertTrue(brainstate.mixin.ParamDesc)
28
- self.assertTrue(brainstate.mixin.ParamDescriber)
29
- self.assertTrue(brainstate.mixin.JointTypes)
30
- self.assertTrue(brainstate.mixin.OneOfTypes)
31
- self.assertTrue(brainstate.mixin.Mode)
32
- self.assertTrue(brainstate.mixin.Batching)
33
- self.assertTrue(brainstate.mixin.Training)
34
-
35
-
36
- class TestMode(unittest.TestCase):
37
- def test_JointMode(self):
38
- a = brainstate.mixin.JointMode(brainstate.mixin.Batching(), brainstate.mixin.Training())
39
- self.assertTrue(a.is_a(brainstate.mixin.JointTypes[brainstate.mixin.Batching, brainstate.mixin.Training]))
40
- self.assertTrue(a.has(brainstate.mixin.Batching))
41
- self.assertTrue(a.has(brainstate.mixin.Training))
42
- b = brainstate.mixin.JointMode(brainstate.mixin.Batching())
43
- self.assertTrue(b.is_a(brainstate.mixin.JointTypes[brainstate.mixin.Batching]))
44
- self.assertTrue(b.is_a(brainstate.mixin.Batching))
45
- self.assertTrue(b.has(brainstate.mixin.Batching))
46
-
47
- def test_Training(self):
48
- a = brainstate.mixin.Training()
49
- self.assertTrue(a.is_a(brainstate.mixin.Training))
50
- self.assertTrue(a.is_a(brainstate.mixin.JointTypes[brainstate.mixin.Training]))
51
- self.assertTrue(a.has(brainstate.mixin.Training))
52
- self.assertTrue(a.has(brainstate.mixin.JointTypes[brainstate.mixin.Training]))
53
- self.assertFalse(a.is_a(brainstate.mixin.Batching))
54
- self.assertFalse(a.has(brainstate.mixin.Batching))
55
-
56
- def test_Batching(self):
57
- a = brainstate.mixin.Batching()
58
- self.assertTrue(a.is_a(brainstate.mixin.Batching))
59
- self.assertTrue(a.is_a(brainstate.mixin.JointTypes[brainstate.mixin.Batching]))
60
- self.assertTrue(a.has(brainstate.mixin.Batching))
61
- self.assertTrue(a.has(brainstate.mixin.JointTypes[brainstate.mixin.Batching]))
62
-
63
- self.assertFalse(a.is_a(brainstate.mixin.Training))
64
- self.assertFalse(a.has(brainstate.mixin.Training))
65
-
66
- def test_Mode(self):
67
- a = brainstate.mixin.Mode()
68
- self.assertTrue(a.is_a(brainstate.mixin.Mode))
69
- self.assertTrue(a.is_a(brainstate.mixin.JointTypes[brainstate.mixin.Mode]))
70
- self.assertTrue(a.has(brainstate.mixin.Mode))
71
- self.assertTrue(a.has(brainstate.mixin.JointTypes[brainstate.mixin.Mode]))
72
-
73
- self.assertFalse(a.is_a(brainstate.mixin.Training))
74
- self.assertFalse(a.has(brainstate.mixin.Training))
75
- self.assertFalse(a.is_a(brainstate.mixin.Batching))
76
- self.assertFalse(a.has(brainstate.mixin.Batching))
77
-
1
+ # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """
17
+ Comprehensive tests for brainstate.mixin module.
18
+
19
+ This test suite covers all functionality in the mixin module including:
20
+ - Base mixin classes
21
+ - Parameter description and deferred instantiation
22
+ - Type utilities (JointTypes, OneOfTypes)
23
+ - Mode system (Mode, JointMode, Training, Batching)
24
+ - Helper utilities (hashable, not_implemented, etc.)
25
+ """
26
+
27
+ import unittest
28
+
29
+ import jax.numpy as jnp
30
+
31
+ import brainstate
32
+
33
+
34
+ class TestHashableFunction(unittest.TestCase):
35
+ """Test the hashable utility function."""
36
+
37
+ def test_hashable_primitives(self):
38
+ """Test hashable with primitive types."""
39
+ self.assertTrue(brainstate.mixin.hashable(42))
40
+ self.assertTrue(brainstate.mixin.hashable(3.14))
41
+ self.assertTrue(brainstate.mixin.hashable("string"))
42
+ self.assertTrue(brainstate.mixin.hashable(True))
43
+ self.assertTrue(brainstate.mixin.hashable(None))
44
+
45
+ def test_hashable_tuples(self):
46
+ """Test hashable with tuples."""
47
+ self.assertTrue(brainstate.mixin.hashable((1, 2, 3)))
48
+ self.assertTrue(brainstate.mixin.hashable(("a", "b")))
49
+ self.assertTrue(brainstate.mixin.hashable(()))
50
+
51
+ def test_non_hashable_types(self):
52
+ """Test non-hashable types."""
53
+ self.assertFalse(brainstate.mixin.hashable([1, 2, 3]))
54
+ self.assertFalse(brainstate.mixin.hashable({"key": "value"}))
55
+ self.assertFalse(brainstate.mixin.hashable({1, 2, 3}))
56
+ self.assertFalse(brainstate.mixin.hashable(jnp.array([1, 2, 3])))
57
+
58
+
59
+ class TestMixin(unittest.TestCase):
60
+ """Test the base Mixin class."""
61
+
62
+ def test_mixin_exists(self):
63
+ """Test that Mixin class exists."""
64
+ self.assertTrue(brainstate.mixin.Mixin)
65
+
66
+ def test_mixin_inheritance(self):
67
+ """Test creating a custom mixin."""
68
+
69
+ class LoggingMixin(brainstate.mixin.Mixin):
70
+ def log(self, message):
71
+ return f"[LOG] {message}"
72
+
73
+ class Component(LoggingMixin):
74
+ pass
75
+
76
+ comp = Component()
77
+ self.assertEqual(comp.log("test"), "[LOG] test")
78
+
79
+ def test_mixin_multiple_inheritance(self):
80
+ """Test multiple mixin inheritance."""
81
+
82
+ class MixinA(brainstate.mixin.Mixin):
83
+ def method_a(self):
84
+ return "A"
85
+
86
+ class MixinB(brainstate.mixin.Mixin):
87
+ def method_b(self):
88
+ return "B"
89
+
90
+ class Component(MixinA, MixinB):
91
+ pass
92
+
93
+ comp = Component()
94
+ self.assertEqual(comp.method_a(), "A")
95
+ self.assertEqual(comp.method_b(), "B")
96
+
97
+
98
+ class TestParamDesc(unittest.TestCase):
99
+ """Test ParamDesc mixin and ParamDescriber."""
100
+
101
+ def test_param_desc_basic(self):
102
+ """Test basic ParamDesc functionality."""
103
+
104
+ class Network(brainstate.mixin.ParamDesc):
105
+ def __init__(self, size, learning_rate=0.01):
106
+ self.size = size
107
+ self.learning_rate = learning_rate
108
+
109
+ # Test desc method exists
110
+ self.assertTrue(hasattr(Network, 'desc'))
111
+
112
+ # Create a descriptor
113
+ desc = Network.desc(size=100)
114
+ self.assertIsInstance(desc, brainstate.mixin.ParamDescriber)
115
+
116
+ def test_param_describer_instantiation(self):
117
+ """Test ParamDescriber can create instances."""
118
+
119
+ class Network(brainstate.mixin.ParamDesc):
120
+ def __init__(self, size, learning_rate=0.01):
121
+ self.size = size
122
+ self.learning_rate = learning_rate
123
+
124
+ desc = Network.desc(size=100, learning_rate=0.001)
125
+
126
+ # Create instances
127
+ net1 = desc()
128
+ self.assertEqual(net1.size, 100)
129
+ self.assertEqual(net1.learning_rate, 0.001)
130
+
131
+ # Create with overrides
132
+ net2 = desc(learning_rate=0.005)
133
+ self.assertEqual(net2.size, 100)
134
+ self.assertEqual(net2.learning_rate, 0.005)
135
+
136
+ def test_param_describer_init_method(self):
137
+ """Test ParamDescriber.init() method."""
138
+
139
+ class Model(brainstate.mixin.ParamDesc):
140
+ def __init__(self, value):
141
+ self.value = value
142
+
143
+ desc = Model.desc(value=42)
144
+ instance = desc.init()
145
+ self.assertEqual(instance.value, 42)
146
+
147
+ def test_param_describer_identifier(self):
148
+ """Test ParamDescriber identifier property."""
149
+
150
+ class Model(brainstate.mixin.ParamDesc):
151
+ def __init__(self, x, y=10):
152
+ self.x = x
153
+ self.y = y
154
+
155
+ desc = Model.desc(x=5, y=20)
156
+ identifier = desc.identifier
157
+
158
+ # Identifier should be a tuple
159
+ self.assertIsInstance(identifier, tuple)
160
+ self.assertEqual(len(identifier), 3)
161
+ self.assertEqual(identifier[0], Model)
162
+
163
+ # Identifier should be read-only
164
+ with self.assertRaises(AttributeError):
165
+ desc.identifier = "new"
166
+
167
+ def test_param_describer_class_getitem(self):
168
+ """Test ParamDescriber[Class] notation."""
169
+
170
+ class Model:
171
+ def __init__(self, value):
172
+ self.value = value
173
+
174
+ desc = brainstate.mixin.ParamDescriber[Model]
175
+ self.assertIsInstance(desc, brainstate.mixin.ParamDescriber)
176
+ self.assertEqual(desc.cls, Model)
177
+
178
+ def test_no_subclass_meta(self):
179
+ """Test that ParamDescriber cannot be subclassed."""
180
+
181
+ with self.assertRaises(TypeError):
182
+ class CustomDescriber(brainstate.mixin.ParamDescriber):
183
+ pass
184
+
185
+
186
+ class TestHashableDict(unittest.TestCase):
187
+ """Test HashableDict class."""
188
+
189
+ def test_hashable_dict_basic(self):
190
+ """Test basic HashableDict functionality."""
191
+ d = brainstate.mixin.HashableDict({"a": 1, "b": 2})
192
+ h = hash(d)
193
+ self.assertIsInstance(h, int)
194
+
195
+ def test_hashable_dict_with_arrays(self):
196
+ """Test HashableDict with non-hashable values."""
197
+ d = brainstate.mixin.HashableDict({
198
+ "array": jnp.array([1, 2, 3]),
199
+ "value": 42
200
+ })
201
+ h = hash(d)
202
+ self.assertIsInstance(h, int)
203
+
204
+ def test_hashable_dict_consistency(self):
205
+ """Test that equal dicts have equal hashes."""
206
+ d1 = brainstate.mixin.HashableDict({"a": 1, "b": 2})
207
+ d2 = brainstate.mixin.HashableDict({"b": 2, "a": 1})
208
+ self.assertEqual(hash(d1), hash(d2))
209
+
210
+ def test_hashable_dict_usable_as_key(self):
211
+ """Test that HashableDict can be used as dict key."""
212
+ d = brainstate.mixin.HashableDict({"x": 10})
213
+ cache = {d: "result"}
214
+ self.assertEqual(cache[d], "result")
215
+
216
+
217
+ class TestJointTypes(unittest.TestCase):
218
+ """Test JointTypes functionality."""
219
+
220
+ def test_joint_types_basic(self):
221
+ """Test basic JointTypes creation."""
222
+
223
+ class A:
224
+ pass
225
+
226
+ class B:
227
+ pass
228
+
229
+ JointAB = brainstate.mixin.JointTypes(A, B)
230
+ self.assertIsNotNone(JointAB)
231
+
232
+ def test_joint_types_isinstance(self):
233
+ """Test isinstance with JointTypes."""
234
+
235
+ class Serializable:
236
+ def save(self):
237
+ pass
238
+
239
+ class Visualizable:
240
+ def plot(self):
241
+ pass
242
+
243
+ Combined = brainstate.mixin.JointTypes(Serializable, Visualizable)
244
+
245
+ class Model(Serializable, Visualizable):
246
+ def save(self):
247
+ return "saved"
248
+
249
+ def plot(self):
250
+ return "plotted"
251
+
252
+ model = Model()
253
+ self.assertTrue(isinstance(model, Combined))
254
+
255
+ def test_joint_types_issubclass(self):
256
+ """Test issubclass with JointTypes."""
257
+
258
+ class A:
259
+ pass
260
+
261
+ class B:
262
+ pass
263
+
264
+ JointAB = brainstate.mixin.JointTypes(A, B)
265
+
266
+ class C(A, B):
267
+ pass
268
+
269
+ self.assertTrue(issubclass(C, JointAB))
270
+
271
+ def test_joint_types_single_type(self):
272
+ """Test JointTypes with single type returns that type."""
273
+
274
+ class A:
275
+ pass
276
+
277
+ result = brainstate.mixin.JointTypes(A)
278
+ self.assertEqual(result, A)
279
+
280
+ def test_joint_types_no_types(self):
281
+ """Test JointTypes with no types raises error."""
282
+ with self.assertRaises(TypeError):
283
+ brainstate.mixin.JointTypes()
284
+
285
+ def test_joint_types_removes_duplicates(self):
286
+ """Test that JointTypes removes duplicate types."""
287
+
288
+ class A:
289
+ pass
290
+
291
+ # Should handle duplicates gracefully
292
+ JointA = brainstate.mixin.JointTypes(A, A, A)
293
+ self.assertEqual(JointA, A)
294
+
295
+
296
+ class TestOneOfTypes(unittest.TestCase):
297
+ """Test OneOfTypes functionality."""
298
+
299
+ def test_one_of_types_basic(self):
300
+ """Test basic OneOfTypes creation."""
301
+ IntOrFloat = brainstate.mixin.OneOfTypes(int, float)
302
+ self.assertIsNotNone(IntOrFloat)
303
+
304
+ def test_one_of_types_isinstance(self):
305
+ """Test isinstance with OneOfTypes."""
306
+ NumType = brainstate.mixin.OneOfTypes(int, float)
307
+
308
+ self.assertTrue(isinstance(42, NumType))
309
+ self.assertTrue(isinstance(3.14, NumType))
310
+ self.assertFalse(isinstance("hello", NumType))
311
+
312
+ def test_one_of_types_single_type(self):
313
+ """Test OneOfTypes with single type returns that type."""
314
+ result = brainstate.mixin.OneOfTypes(int)
315
+ self.assertEqual(result, int)
316
+
317
+ def test_one_of_types_no_types(self):
318
+ """Test OneOfTypes with no types raises error."""
319
+ with self.assertRaises(TypeError):
320
+ brainstate.mixin.OneOfTypes()
321
+
322
+ def test_one_of_types_with_none(self):
323
+ """Test OneOfTypes with None for optional types."""
324
+ MaybeInt = brainstate.mixin.OneOfTypes(int, type(None))
325
+
326
+ self.assertTrue(isinstance(42, MaybeInt))
327
+ self.assertTrue(isinstance(None, MaybeInt))
328
+ self.assertFalse(isinstance("hello", MaybeInt))
329
+
330
+
331
+
332
+ class TestNotImplemented(unittest.TestCase):
333
+ """Test not_implemented decorator."""
334
+
335
+ def test_not_implemented_decorator(self):
336
+ """Test not_implemented decorator marks functions."""
337
+
338
+ @brainstate.mixin.not_implemented
339
+ def my_function():
340
+ pass
341
+
342
+ self.assertTrue(hasattr(my_function, 'not_implemented'))
343
+ self.assertTrue(my_function.not_implemented)
344
+
345
+ def test_not_implemented_raises(self):
346
+ """Test not_implemented decorator raises error when called."""
347
+
348
+ @brainstate.mixin.not_implemented
349
+ def my_function():
350
+ pass
351
+
352
+ with self.assertRaises(NotImplementedError) as cm:
353
+ my_function()
354
+
355
+ self.assertIn("my_function", str(cm.exception))
356
+
357
+
358
+ class TestMode(unittest.TestCase):
359
+ """Test Mode base class."""
360
+
361
+ def test_mode_creation(self):
362
+ """Test basic Mode creation."""
363
+ mode = brainstate.mixin.Mode()
364
+ self.assertIsNotNone(mode)
365
+
366
+ def test_mode_repr(self):
367
+ """Test Mode string representation."""
368
+ mode = brainstate.mixin.Mode()
369
+ self.assertEqual(repr(mode), "Mode")
370
+
371
+ def test_mode_equality(self):
372
+ """Test Mode equality comparison."""
373
+ mode1 = brainstate.mixin.Mode()
374
+ mode2 = brainstate.mixin.Mode()
375
+ self.assertEqual(mode1, mode2)
376
+
377
+ def test_mode_is_a(self):
378
+ """Test Mode.is_a() method."""
379
+ mode = brainstate.mixin.Mode()
380
+ self.assertTrue(mode.is_a(brainstate.mixin.Mode))
381
+ self.assertFalse(mode.is_a(brainstate.mixin.Training))
382
+
383
+ def test_mode_has(self):
384
+ """Test Mode.has() method."""
385
+ mode = brainstate.mixin.Mode()
386
+ self.assertTrue(mode.has(brainstate.mixin.Mode))
387
+ self.assertFalse(mode.has(brainstate.mixin.Training))
388
+
389
+ def test_custom_mode(self):
390
+ """Test creating custom mode."""
391
+
392
+ class CustomMode(brainstate.mixin.Mode):
393
+ def __init__(self, value):
394
+ self.value = value
395
+
396
+ mode = CustomMode(42)
397
+ self.assertEqual(mode.value, 42)
398
+ self.assertTrue(mode.has(brainstate.mixin.Mode))
399
+
400
+
401
+ class TestTraining(unittest.TestCase):
402
+ """Test Training mode."""
403
+
404
+ def test_training_creation(self):
405
+ """Test Training mode creation."""
406
+ training = brainstate.mixin.Training()
407
+ self.assertIsNotNone(training)
408
+
409
+ def test_training_is_mode(self):
410
+ """Test Training is a Mode."""
411
+ training = brainstate.mixin.Training()
412
+ self.assertTrue(training.has(brainstate.mixin.Mode))
413
+
414
+ def test_training_is_a(self):
415
+ """Test Training.is_a() method."""
416
+ training = brainstate.mixin.Training()
417
+ self.assertTrue(training.is_a(brainstate.mixin.Training))
418
+ self.assertFalse(training.is_a(brainstate.mixin.Batching))
419
+
420
+ def test_training_has(self):
421
+ """Test Training.has() method."""
422
+ training = brainstate.mixin.Training()
423
+ self.assertTrue(training.has(brainstate.mixin.Training))
424
+ self.assertFalse(training.has(brainstate.mixin.Batching))
425
+
426
+ def test_training_joint_types(self):
427
+ """Test Training with JointTypes."""
428
+ training = brainstate.mixin.Training()
429
+ self.assertTrue(training.is_a(brainstate.mixin.JointTypes(brainstate.mixin.Training)))
430
+ self.assertTrue(training.has(brainstate.mixin.JointTypes(brainstate.mixin.Training)))
431
+
432
+
433
+ class TestBatching(unittest.TestCase):
434
+ """Test Batching mode."""
435
+
436
+ def test_batching_creation(self):
437
+ """Test Batching mode creation."""
438
+ batching = brainstate.mixin.Batching()
439
+ self.assertIsNotNone(batching)
440
+
441
+ def test_batching_default_params(self):
442
+ """Test Batching default parameters."""
443
+ batching = brainstate.mixin.Batching()
444
+ self.assertEqual(batching.batch_size, 1)
445
+ self.assertEqual(batching.batch_axis, 0)
446
+
447
+ def test_batching_custom_params(self):
448
+ """Test Batching with custom parameters."""
449
+ batching = brainstate.mixin.Batching(batch_size=32, batch_axis=1)
450
+ self.assertEqual(batching.batch_size, 32)
451
+ self.assertEqual(batching.batch_axis, 1)
452
+
453
+ def test_batching_repr(self):
454
+ """Test Batching string representation."""
455
+ batching = brainstate.mixin.Batching(batch_size=64, batch_axis=0)
456
+ self.assertIn("64", repr(batching))
457
+ self.assertIn("0", repr(batching))
458
+
459
+ def test_batching_is_mode(self):
460
+ """Test Batching is a Mode."""
461
+ batching = brainstate.mixin.Batching()
462
+ self.assertTrue(batching.has(brainstate.mixin.Mode))
463
+
464
+ def test_batching_is_a(self):
465
+ """Test Batching.is_a() method."""
466
+ batching = brainstate.mixin.Batching()
467
+ self.assertTrue(batching.is_a(brainstate.mixin.Batching))
468
+ self.assertFalse(batching.is_a(brainstate.mixin.Training))
469
+
470
+ def test_batching_has(self):
471
+ """Test Batching.has() method."""
472
+ batching = brainstate.mixin.Batching()
473
+ self.assertTrue(batching.has(brainstate.mixin.Batching))
474
+ self.assertFalse(batching.has(brainstate.mixin.Training))
475
+
476
+
477
+ class TestJointMode(unittest.TestCase):
478
+ """Test JointMode functionality."""
479
+
480
+ def test_joint_mode_creation(self):
481
+ """Test JointMode creation."""
482
+ training = brainstate.mixin.Training()
483
+ batching = brainstate.mixin.Batching()
484
+ joint = brainstate.mixin.JointMode(training, batching)
485
+ self.assertIsNotNone(joint)
486
+
487
+ def test_joint_mode_repr(self):
488
+ """Test JointMode string representation."""
489
+ training = brainstate.mixin.Training()
490
+ batching = brainstate.mixin.Batching(batch_size=32)
491
+ joint = brainstate.mixin.JointMode(training, batching)
492
+
493
+ repr_str = repr(joint)
494
+ self.assertIn("JointMode", repr_str)
495
+ self.assertIn("Training", repr_str)
496
+ self.assertIn("Batching", repr_str)
497
+
498
+ def test_joint_mode_has(self):
499
+ """Test JointMode.has() method."""
500
+ training = brainstate.mixin.Training()
501
+ batching = brainstate.mixin.Batching()
502
+ joint = brainstate.mixin.JointMode(training, batching)
503
+
504
+ self.assertTrue(joint.has(brainstate.mixin.Training))
505
+ self.assertTrue(joint.has(brainstate.mixin.Batching))
506
+ self.assertTrue(joint.has(brainstate.mixin.Mode))
507
+
508
+ def test_joint_mode_is_a(self):
509
+ """Test JointMode.is_a() method."""
510
+ training = brainstate.mixin.Training()
511
+ batching = brainstate.mixin.Batching()
512
+ joint = brainstate.mixin.JointMode(training, batching)
513
+
514
+ # JointMode.is_a() works by checking if the JointTypes of the mode types
515
+ # matches the expected type. This is a complex comparison.
516
+ # For practical use, test that it correctly identifies single types
517
+ self.assertFalse(joint.is_a(brainstate.mixin.Training)) # Not just Training
518
+ self.assertFalse(joint.is_a(brainstate.mixin.Batching)) # Not just Batching
519
+
520
+ # But a single mode joint should match
521
+ single_joint = brainstate.mixin.JointMode(training)
522
+ self.assertTrue(single_joint.is_a(brainstate.mixin.Training))
523
+
524
+ def test_joint_mode_single_mode(self):
525
+ """Test JointMode with single mode."""
526
+ batching = brainstate.mixin.Batching()
527
+ joint = brainstate.mixin.JointMode(batching)
528
+
529
+ self.assertTrue(joint.has(brainstate.mixin.Batching))
530
+ self.assertTrue(joint.is_a(brainstate.mixin.Batching))
531
+
532
+ def test_joint_mode_attribute_access(self):
533
+ """Test JointMode attribute delegation."""
534
+ batching = brainstate.mixin.Batching(batch_size=32, batch_axis=1)
535
+ training = brainstate.mixin.Training()
536
+ joint = brainstate.mixin.JointMode(batching, training)
537
+
538
+ # Should access batching attributes
539
+ self.assertEqual(joint.batch_size, 32)
540
+ self.assertEqual(joint.batch_axis, 1)
541
+
542
+ def test_joint_mode_invalid_type(self):
543
+ """Test JointMode with non-Mode raises error."""
544
+ with self.assertRaises(TypeError):
545
+ brainstate.mixin.JointMode("not a mode")
546
+
547
+ def test_joint_mode_modes_attribute(self):
548
+ """Test accessing modes attribute."""
549
+ training = brainstate.mixin.Training()
550
+ batching = brainstate.mixin.Batching()
551
+ joint = brainstate.mixin.JointMode(training, batching)
552
+
553
+ self.assertEqual(len(joint.modes), 2)
554
+ self.assertIn(training, joint.modes)
555
+ self.assertIn(batching, joint.modes)
556
+
557
+ def test_joint_mode_types_attribute(self):
558
+ """Test accessing types attribute."""
559
+ training = brainstate.mixin.Training()
560
+ batching = brainstate.mixin.Batching()
561
+ joint = brainstate.mixin.JointMode(training, batching)
562
+
563
+ self.assertEqual(len(joint.types), 2)
564
+ self.assertIn(brainstate.mixin.Training, joint.types)
565
+ self.assertIn(brainstate.mixin.Batching, joint.types)
566
+
567
+
568
class TestIntegration(unittest.TestCase):
    """Integration tests combining multiple features."""

    def test_param_desc_with_modes(self):
        """Test ParamDesc with Mode system."""

        class Model(brainstate.mixin.ParamDesc):
            def __init__(self, size, mode=None):
                self.size = size
                self.mode = mode if mode is not None else brainstate.mixin.Mode()

        # Create a descriptor carrying a training mode; calling the descriptor
        # builds the model with the stored keyword arguments.
        train_model_desc = Model.desc(size=100, mode=brainstate.mixin.Training())
        model = train_model_desc()

        self.assertEqual(model.size, 100)
        self.assertTrue(model.mode.has(brainstate.mixin.Training))

    def test_joint_types_with_multiple_mixins(self):
        """Test JointTypes with multiple mixin classes."""

        class Serializable(brainstate.mixin.Mixin):
            def save(self):
                return "saved"

        class Trainable(brainstate.mixin.Mixin):
            def train(self):
                return "trained"

        class Evaluable(brainstate.mixin.Mixin):
            def evaluate(self):
                return "evaluated"

        FullModel = brainstate.mixin.JointTypes(Serializable, Trainable, Evaluable)

        class MyModel(Serializable, Trainable, Evaluable):
            pass

        model = MyModel()
        # assertIsInstance reports the object and the expected type on failure,
        # unlike assertTrue(isinstance(...)), which only says "False is not true".
        self.assertIsInstance(model, FullModel)
        self.assertEqual(model.save(), "saved")
        self.assertEqual(model.train(), "trained")
        self.assertEqual(model.evaluate(), "evaluated")

    def test_complex_mode_scenario(self):
        """Test complex scenario with multiple modes."""

        class NeuralNetwork:
            def __init__(self):
                self.mode = None

            def set_mode(self, mode):
                self.mode = mode

            def forward(self, x):
                # No mode set: identity.
                if self.mode is None:
                    return x

                if self.mode.has(brainstate.mixin.Training):
                    # Add noise during training
                    x = x + 0.1

                if self.mode.has(brainstate.mixin.Batching):
                    # Return the batch size alongside the output so the test
                    # can verify JointMode forwards Batching attributes.
                    batch_size = self.mode.batch_size
                    return x, batch_size

                return x

        net = NeuralNetwork()

        # Test evaluation mode
        result = net.forward(1.0)
        self.assertEqual(result, 1.0)

        # Test training mode
        net.set_mode(brainstate.mixin.Training())
        result = net.forward(1.0)
        self.assertAlmostEqual(result, 1.1)

        # Test joint mode
        training = brainstate.mixin.Training()
        batching = brainstate.mixin.Batching(batch_size=32)
        net.set_mode(brainstate.mixin.JointMode(training, batching))

        result, batch_size = net.forward(1.0)
        self.assertAlmostEqual(result, 1.1)
        self.assertEqual(batch_size, 32)
657
+
658
+
659
class TestJointTypesComprehensive(unittest.TestCase):
    """Comprehensive tests for JointTypes special methods and functionality."""

    def setUp(self):
        """Set up test classes."""
        class A:
            pass

        class B:
            pass

        class C:
            pass

        self.A = A
        self.B = B
        self.C = C

    def test_repr(self):
        """Test __repr__ method."""
        JT = brainstate.mixin.JointTypes[self.A, self.B]
        repr_str = repr(JT)
        self.assertIn('JointTypes', repr_str)
        self.assertIn('A', repr_str)
        self.assertIn('B', repr_str)

    def test_eq_same_order(self):
        """Test equality with same type order."""
        JT1 = brainstate.mixin.JointTypes[self.A, self.B]
        JT2 = brainstate.mixin.JointTypes[self.A, self.B]
        self.assertEqual(JT1, JT2)

    def test_eq_different_order(self):
        """Test equality with different type order."""
        JT1 = brainstate.mixin.JointTypes[self.A, self.B]
        JT2 = brainstate.mixin.JointTypes[self.B, self.A]
        self.assertEqual(JT1, JT2)

    def test_eq_different_types(self):
        """Test inequality with different types."""
        JT1 = brainstate.mixin.JointTypes[self.A, self.B]
        JT2 = brainstate.mixin.JointTypes[self.A, self.C]
        self.assertNotEqual(JT1, JT2)

    def test_eq_with_non_jointtypes(self):
        """Test equality with non-JointTypes object."""
        JT = brainstate.mixin.JointTypes[self.A, self.B]
        self.assertNotEqual(JT, "not a type")
        self.assertNotEqual(JT, 42)
        self.assertNotEqual(JT, self.A)

    def test_hash_consistency(self):
        """Test hash consistency."""
        JT = brainstate.mixin.JointTypes[self.A, self.B]
        hash1 = hash(JT)
        hash2 = hash(JT)
        self.assertEqual(hash1, hash2)

    def test_hash_order_independent(self):
        """Test hash is order-independent."""
        JT1 = brainstate.mixin.JointTypes[self.A, self.B]
        JT2 = brainstate.mixin.JointTypes[self.B, self.A]
        self.assertEqual(hash(JT1), hash(JT2))

    def test_hash_different_for_different_types(self):
        """Test different types have different hashes."""
        JT1 = brainstate.mixin.JointTypes[self.A, self.B]
        JT2 = brainstate.mixin.JointTypes[self.A, self.C]
        # Note: hash collision is possible but unlikely for different types
        self.assertNotEqual(hash(JT1), hash(JT2))

    def test_hashable_in_set(self):
        """Test JointTypes can be used in sets."""
        JT1 = brainstate.mixin.JointTypes[self.A, self.B]
        JT2 = brainstate.mixin.JointTypes[self.B, self.A]
        JT3 = brainstate.mixin.JointTypes[self.A, self.C]

        type_set = {JT1, JT2, JT3}
        # JT1 and JT2 are equal, so set should have 2 elements
        self.assertEqual(len(type_set), 2)
        self.assertIn(JT1, type_set)
        self.assertIn(JT2, type_set)
        self.assertIn(JT3, type_set)

    def test_as_dict_key(self):
        """Test JointTypes can be used as dict keys."""
        JT1 = brainstate.mixin.JointTypes[self.A, self.B]
        JT2 = brainstate.mixin.JointTypes[self.B, self.A]

        type_dict = {JT1: "AB type"}
        self.assertIn(JT1, type_dict)
        # JT2 should work as key since it's equal to JT1
        self.assertIn(JT2, type_dict)
        self.assertEqual(type_dict[JT2], "AB type")

    def test_pickle_roundtrip(self):
        """Test pickling and unpickling with built-in types."""
        import pickle
        # Use built-in types since local classes can't be pickled
        JT = brainstate.mixin.JointTypes[int, str]
        pickled = pickle.dumps(JT)
        unpickled = pickle.loads(pickled)
        self.assertEqual(JT, unpickled)
        self.assertEqual(hash(JT), hash(unpickled))

    def test_pickle_preserves_isinstance(self):
        """Test isinstance works after pickle with built-in types."""
        import pickle
        # Use built-in types since locally defined classes can't be pickled.
        JT = brainstate.mixin.JointTypes[int, object]
        pickled = pickle.dumps(JT)
        unpickled = pickle.loads(pickled)

        obj = 42
        self.assertIsInstance(obj, JT)
        self.assertIsInstance(obj, unpickled)
        # A str is an ``object`` but not an ``int``, so it must fail the joint
        # (all-of) check both before and after the round trip. Without this,
        # the positive case alone cannot detect a broken unpickled check.
        self.assertNotIsInstance("hello", JT)
        self.assertNotIsInstance("hello", unpickled)

    def test_multiple_types(self):
        """Test JointTypes with more than 2 types."""
        JT = brainstate.mixin.JointTypes[self.A, self.B, self.C]

        class ABC(self.A, self.B, self.C):
            pass

        self.assertTrue(issubclass(ABC, JT))

        class AB(self.A, self.B):
            pass

        # Missing C, so AB does not satisfy all three types.
        self.assertFalse(issubclass(AB, JT))

    def test_subscript_vs_call_syntax(self):
        """Test subscript and call syntax produce equal results."""
        JT_subscript = brainstate.mixin.JointTypes[self.A, self.B]
        JT_call = brainstate.mixin.JointTypes(self.A, self.B)
        self.assertEqual(JT_subscript, JT_call)
        self.assertEqual(hash(JT_subscript), hash(JT_call))

    def test_args_attribute(self):
        """Test __args__ attribute contains correct types."""
        JT = brainstate.mixin.JointTypes[self.A, self.B]
        self.assertIn(self.A, JT.__args__)
        self.assertIn(self.B, JT.__args__)
        self.assertEqual(len(JT.__args__), 2)
808
+
809
+
810
class TestOneOfTypesComprehensive(unittest.TestCase):
    """Comprehensive tests for OneOfTypes special methods and functionality."""

    def setUp(self):
        """Set up test classes."""
        class A:
            pass

        class B:
            pass

        class C:
            pass

        self.A = A
        self.B = B
        self.C = C

    def test_repr(self):
        """Test __repr__ method."""
        OT = brainstate.mixin.OneOfTypes[self.A, self.B]
        repr_str = repr(OT)
        self.assertIn('OneOfTypes', repr_str)
        self.assertIn('A', repr_str)
        self.assertIn('B', repr_str)

    def test_eq_same_order(self):
        """Test equality with same type order."""
        OT1 = brainstate.mixin.OneOfTypes[self.A, self.B]
        OT2 = brainstate.mixin.OneOfTypes[self.A, self.B]
        self.assertEqual(OT1, OT2)

    def test_eq_different_order(self):
        """Test equality with different type order."""
        OT1 = brainstate.mixin.OneOfTypes[self.A, self.B]
        OT2 = brainstate.mixin.OneOfTypes[self.B, self.A]
        self.assertEqual(OT1, OT2)

    def test_eq_different_types(self):
        """Test inequality with different types."""
        OT1 = brainstate.mixin.OneOfTypes[self.A, self.B]
        OT2 = brainstate.mixin.OneOfTypes[self.A, self.C]
        self.assertNotEqual(OT1, OT2)

    def test_eq_with_non_oneoftypes(self):
        """Test equality with non-OneOfTypes object."""
        OT = brainstate.mixin.OneOfTypes[self.A, self.B]
        self.assertNotEqual(OT, "not a type")
        self.assertNotEqual(OT, 42)
        self.assertNotEqual(OT, self.A)

    def test_hash_consistency(self):
        """Test hash consistency."""
        OT = brainstate.mixin.OneOfTypes[self.A, self.B]
        hash1 = hash(OT)
        hash2 = hash(OT)
        self.assertEqual(hash1, hash2)

    def test_hash_order_independent(self):
        """Test hash is order-independent."""
        OT1 = brainstate.mixin.OneOfTypes[self.A, self.B]
        OT2 = brainstate.mixin.OneOfTypes[self.B, self.A]
        self.assertEqual(hash(OT1), hash(OT2))

    def test_hash_different_for_different_types(self):
        """Test different types have different hashes."""
        OT1 = brainstate.mixin.OneOfTypes[self.A, self.B]
        OT2 = brainstate.mixin.OneOfTypes[self.A, self.C]
        # Note: hash collision is possible but unlikely for different types
        self.assertNotEqual(hash(OT1), hash(OT2))

    def test_hashable_in_set(self):
        """Test OneOfTypes can be used in sets."""
        OT1 = brainstate.mixin.OneOfTypes[self.A, self.B]
        OT2 = brainstate.mixin.OneOfTypes[self.B, self.A]
        OT3 = brainstate.mixin.OneOfTypes[self.A, self.C]

        type_set = {OT1, OT2, OT3}
        # OT1 and OT2 are equal, so set should have 2 elements
        self.assertEqual(len(type_set), 2)
        self.assertIn(OT1, type_set)
        self.assertIn(OT2, type_set)
        self.assertIn(OT3, type_set)

    def test_as_dict_key(self):
        """Test OneOfTypes can be used as dict keys."""
        OT1 = brainstate.mixin.OneOfTypes[self.A, self.B]
        OT2 = brainstate.mixin.OneOfTypes[self.B, self.A]

        type_dict = {OT1: "A or B type"}
        self.assertIn(OT1, type_dict)
        # OT2 should work as key since it's equal to OT1
        self.assertIn(OT2, type_dict)
        self.assertEqual(type_dict[OT2], "A or B type")

    def test_pickle_roundtrip(self):
        """Test pickling and unpickling with built-in types."""
        import pickle
        # Use built-in types since local classes can't be pickled
        OT = brainstate.mixin.OneOfTypes[int, str]
        pickled = pickle.dumps(OT)
        unpickled = pickle.loads(pickled)
        self.assertEqual(OT, unpickled)
        self.assertEqual(hash(OT), hash(unpickled))

    def test_pickle_preserves_isinstance(self):
        """Test isinstance works after pickle with built-in types."""
        import pickle
        # Use built-in types for pickling
        OT = brainstate.mixin.OneOfTypes[int, str]
        pickled = pickle.dumps(OT)
        unpickled = pickle.loads(pickled)

        obj_a = 42
        obj_b = "hello"

        self.assertIsInstance(obj_a, OT)
        self.assertIsInstance(obj_a, unpickled)
        self.assertIsInstance(obj_b, OT)
        self.assertIsInstance(obj_b, unpickled)
        # A float is neither int nor str, so the unpickled check must still
        # reject it. The positive cases alone could not catch an unpickled
        # alias that accepts everything.
        self.assertNotIsInstance(3.14, OT)
        self.assertNotIsInstance(3.14, unpickled)

    def test_isinstance_with_any_type(self):
        """Test isinstance returns True if object is instance of any type."""
        OT = brainstate.mixin.OneOfTypes[self.A, self.B]

        obj_a = self.A()
        obj_b = self.B()
        obj_c = self.C()

        self.assertIsInstance(obj_a, OT)
        self.assertIsInstance(obj_b, OT)
        self.assertNotIsInstance(obj_c, OT)

    def test_issubclass_with_any_type(self):
        """Test issubclass returns True if class is subclass of any type."""
        OT = brainstate.mixin.OneOfTypes[self.A, self.B]

        class SubA(self.A):
            pass

        class SubB(self.B):
            pass

        self.assertTrue(issubclass(SubA, OT))
        self.assertTrue(issubclass(SubB, OT))
        self.assertTrue(issubclass(self.A, OT))
        self.assertTrue(issubclass(self.B, OT))
        self.assertFalse(issubclass(self.C, OT))

    def test_multiple_types(self):
        """Test OneOfTypes with more than 2 types."""
        OT = brainstate.mixin.OneOfTypes[self.A, self.B, self.C]

        obj_a = self.A()
        obj_b = self.B()
        obj_c = self.C()

        self.assertIsInstance(obj_a, OT)
        self.assertIsInstance(obj_b, OT)
        self.assertIsInstance(obj_c, OT)

    def test_subscript_vs_call_syntax(self):
        """Test subscript and call syntax produce equal results."""
        OT_subscript = brainstate.mixin.OneOfTypes[self.A, self.B]
        OT_call = brainstate.mixin.OneOfTypes(self.A, self.B)
        self.assertEqual(OT_subscript, OT_call)
        self.assertEqual(hash(OT_subscript), hash(OT_call))

    def test_args_attribute(self):
        """Test __args__ attribute contains correct types."""
        OT = brainstate.mixin.OneOfTypes[self.A, self.B]
        self.assertIn(self.A, OT.__args__)
        self.assertIn(self.B, OT.__args__)
        self.assertEqual(len(OT.__args__), 2)

    def test_with_builtin_types(self):
        """Test OneOfTypes with built-in types."""
        OT = brainstate.mixin.OneOfTypes[int, float, str]

        self.assertIsInstance(42, OT)
        self.assertIsInstance(3.14, OT)
        self.assertIsInstance("hello", OT)
        self.assertNotIsInstance([], OT)
991
+
992
+
993
class TestJointTy(unittest.TestCase):
    """Check that call and subscript syntax build equal type combinators.

    Previously this was a plain class containing only ``print`` calls. It was
    never collected by ``unittest.main()`` and could not fail. It now
    subclasses ``unittest.TestCase`` and asserts the equalities it used to print.
    """

    def test1(self):
        class Potassium:
            pass

        class Calcium:
            pass

        # JointTypes: ``JointTypes(A, B)`` and ``JointTypes[A, B]`` must agree.
        result1 = brainstate.mixin.JointTypes(Potassium, Calcium)
        result2 = brainstate.mixin.JointTypes[Potassium, Calcium]
        self.assertEqual(result1, result2)

        # OneOfTypes: the same contract holds for the "any of" combinator.
        result3 = brainstate.mixin.OneOfTypes(Potassium, Calcium)
        result4 = brainstate.mixin.OneOfTypes[Potassium, Calcium]
        self.assertEqual(result3, result4)
1014
+
1015
+
1016
if __name__ == '__main__':
    # Running this file directly loads the tests from the ``__main__`` module
    # (``module='__main__'`` is unittest.main's default, spelled out here).
    unittest.main(module='__main__')