brainstate 0.2.0__py2.py3-none-any.whl → 0.2.1__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (112) hide show
  1. brainstate/__init__.py +169 -169
  2. brainstate/_compatible_import.py +340 -340
  3. brainstate/_compatible_import_test.py +681 -681
  4. brainstate/_deprecation.py +210 -210
  5. brainstate/_deprecation_test.py +2319 -2319
  6. brainstate/_error.py +45 -45
  7. brainstate/_state.py +1652 -1652
  8. brainstate/_state_test.py +52 -52
  9. brainstate/_utils.py +47 -47
  10. brainstate/environ.py +1495 -1495
  11. brainstate/environ_test.py +1223 -1223
  12. brainstate/graph/__init__.py +22 -22
  13. brainstate/graph/_node.py +240 -240
  14. brainstate/graph/_node_test.py +589 -589
  15. brainstate/graph/_operation.py +1624 -1624
  16. brainstate/graph/_operation_test.py +1147 -1147
  17. brainstate/mixin.py +1433 -1433
  18. brainstate/mixin_test.py +1017 -1017
  19. brainstate/nn/__init__.py +137 -137
  20. brainstate/nn/_activations.py +1100 -1100
  21. brainstate/nn/_activations_test.py +354 -354
  22. brainstate/nn/_collective_ops.py +633 -633
  23. brainstate/nn/_collective_ops_test.py +774 -774
  24. brainstate/nn/_common.py +226 -226
  25. brainstate/nn/_common_test.py +154 -154
  26. brainstate/nn/_conv.py +2010 -2010
  27. brainstate/nn/_conv_test.py +849 -849
  28. brainstate/nn/_delay.py +575 -575
  29. brainstate/nn/_delay_test.py +243 -243
  30. brainstate/nn/_dropout.py +618 -618
  31. brainstate/nn/_dropout_test.py +477 -477
  32. brainstate/nn/_dynamics.py +1267 -1267
  33. brainstate/nn/_dynamics_test.py +67 -67
  34. brainstate/nn/_elementwise.py +1298 -1298
  35. brainstate/nn/_elementwise_test.py +829 -829
  36. brainstate/nn/_embedding.py +408 -408
  37. brainstate/nn/_embedding_test.py +156 -156
  38. brainstate/nn/_event_fixedprob.py +233 -233
  39. brainstate/nn/_event_fixedprob_test.py +115 -115
  40. brainstate/nn/_event_linear.py +83 -83
  41. brainstate/nn/_event_linear_test.py +121 -121
  42. brainstate/nn/_exp_euler.py +254 -254
  43. brainstate/nn/_exp_euler_test.py +377 -377
  44. brainstate/nn/_linear.py +744 -744
  45. brainstate/nn/_linear_test.py +475 -475
  46. brainstate/nn/_metrics.py +1070 -1070
  47. brainstate/nn/_metrics_test.py +611 -611
  48. brainstate/nn/_module.py +384 -384
  49. brainstate/nn/_module_test.py +40 -40
  50. brainstate/nn/_normalizations.py +1334 -1334
  51. brainstate/nn/_normalizations_test.py +699 -699
  52. brainstate/nn/_paddings.py +1020 -1020
  53. brainstate/nn/_paddings_test.py +722 -722
  54. brainstate/nn/_poolings.py +2239 -2239
  55. brainstate/nn/_poolings_test.py +952 -952
  56. brainstate/nn/_rnns.py +946 -946
  57. brainstate/nn/_rnns_test.py +592 -592
  58. brainstate/nn/_utils.py +216 -216
  59. brainstate/nn/_utils_test.py +401 -401
  60. brainstate/nn/init.py +809 -809
  61. brainstate/nn/init_test.py +180 -180
  62. brainstate/random/__init__.py +270 -270
  63. brainstate/random/_rand_funs.py +3938 -3938
  64. brainstate/random/_rand_funs_test.py +640 -640
  65. brainstate/random/_rand_seed.py +675 -675
  66. brainstate/random/_rand_seed_test.py +48 -48
  67. brainstate/random/_rand_state.py +1617 -1617
  68. brainstate/random/_rand_state_test.py +551 -551
  69. brainstate/transform/__init__.py +59 -59
  70. brainstate/transform/_ad_checkpoint.py +176 -176
  71. brainstate/transform/_ad_checkpoint_test.py +49 -49
  72. brainstate/transform/_autograd.py +1025 -1025
  73. brainstate/transform/_autograd_test.py +1289 -1289
  74. brainstate/transform/_conditions.py +316 -316
  75. brainstate/transform/_conditions_test.py +220 -220
  76. brainstate/transform/_error_if.py +94 -94
  77. brainstate/transform/_error_if_test.py +52 -52
  78. brainstate/transform/_eval_shape.py +145 -145
  79. brainstate/transform/_eval_shape_test.py +38 -38
  80. brainstate/transform/_jit.py +399 -399
  81. brainstate/transform/_jit_test.py +143 -143
  82. brainstate/transform/_loop_collect_return.py +675 -675
  83. brainstate/transform/_loop_collect_return_test.py +58 -58
  84. brainstate/transform/_loop_no_collection.py +283 -283
  85. brainstate/transform/_loop_no_collection_test.py +50 -50
  86. brainstate/transform/_make_jaxpr.py +2016 -2016
  87. brainstate/transform/_make_jaxpr_test.py +1510 -1510
  88. brainstate/transform/_mapping.py +529 -529
  89. brainstate/transform/_mapping_test.py +194 -194
  90. brainstate/transform/_progress_bar.py +255 -255
  91. brainstate/transform/_random.py +171 -171
  92. brainstate/transform/_unvmap.py +256 -256
  93. brainstate/transform/_util.py +286 -286
  94. brainstate/typing.py +837 -837
  95. brainstate/typing_test.py +780 -780
  96. brainstate/util/__init__.py +27 -27
  97. brainstate/util/_others.py +1024 -1024
  98. brainstate/util/_others_test.py +962 -962
  99. brainstate/util/_pretty_pytree.py +1301 -1301
  100. brainstate/util/_pretty_pytree_test.py +675 -675
  101. brainstate/util/_pretty_repr.py +462 -462
  102. brainstate/util/_pretty_repr_test.py +696 -696
  103. brainstate/util/filter.py +945 -945
  104. brainstate/util/filter_test.py +911 -911
  105. brainstate/util/struct.py +910 -910
  106. brainstate/util/struct_test.py +602 -602
  107. {brainstate-0.2.0.dist-info → brainstate-0.2.1.dist-info}/METADATA +108 -108
  108. brainstate-0.2.1.dist-info/RECORD +111 -0
  109. {brainstate-0.2.0.dist-info → brainstate-0.2.1.dist-info}/licenses/LICENSE +202 -202
  110. brainstate-0.2.0.dist-info/RECORD +0 -111
  111. {brainstate-0.2.0.dist-info → brainstate-0.2.1.dist-info}/WHEEL +0 -0
  112. {brainstate-0.2.0.dist-info → brainstate-0.2.1.dist-info}/top_level.txt +0 -0
brainstate/mixin_test.py CHANGED
@@ -1,1017 +1,1017 @@
1
- # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
- """
17
- Comprehensive tests for brainstate.mixin module.
18
-
19
- This test suite covers all functionality in the mixin module including:
20
- - Base mixin classes
21
- - Parameter description and deferred instantiation
22
- - Type utilities (JointTypes, OneOfTypes)
23
- - Mode system (Mode, JointMode, Training, Batching)
24
- - Helper utilities (hashable, not_implemented, etc.)
25
- """
26
-
27
- import unittest
28
-
29
- import jax.numpy as jnp
30
-
31
- import brainstate
32
-
33
-
34
- class TestHashableFunction(unittest.TestCase):
35
- """Test the hashable utility function."""
36
-
37
- def test_hashable_primitives(self):
38
- """Test hashable with primitive types."""
39
- self.assertTrue(brainstate.mixin.hashable(42))
40
- self.assertTrue(brainstate.mixin.hashable(3.14))
41
- self.assertTrue(brainstate.mixin.hashable("string"))
42
- self.assertTrue(brainstate.mixin.hashable(True))
43
- self.assertTrue(brainstate.mixin.hashable(None))
44
-
45
- def test_hashable_tuples(self):
46
- """Test hashable with tuples."""
47
- self.assertTrue(brainstate.mixin.hashable((1, 2, 3)))
48
- self.assertTrue(brainstate.mixin.hashable(("a", "b")))
49
- self.assertTrue(brainstate.mixin.hashable(()))
50
-
51
- def test_non_hashable_types(self):
52
- """Test non-hashable types."""
53
- self.assertFalse(brainstate.mixin.hashable([1, 2, 3]))
54
- self.assertFalse(brainstate.mixin.hashable({"key": "value"}))
55
- self.assertFalse(brainstate.mixin.hashable({1, 2, 3}))
56
- self.assertFalse(brainstate.mixin.hashable(jnp.array([1, 2, 3])))
57
-
58
-
59
- class TestMixin(unittest.TestCase):
60
- """Test the base Mixin class."""
61
-
62
- def test_mixin_exists(self):
63
- """Test that Mixin class exists."""
64
- self.assertTrue(brainstate.mixin.Mixin)
65
-
66
- def test_mixin_inheritance(self):
67
- """Test creating a custom mixin."""
68
-
69
- class LoggingMixin(brainstate.mixin.Mixin):
70
- def log(self, message):
71
- return f"[LOG] {message}"
72
-
73
- class Component(LoggingMixin):
74
- pass
75
-
76
- comp = Component()
77
- self.assertEqual(comp.log("test"), "[LOG] test")
78
-
79
- def test_mixin_multiple_inheritance(self):
80
- """Test multiple mixin inheritance."""
81
-
82
- class MixinA(brainstate.mixin.Mixin):
83
- def method_a(self):
84
- return "A"
85
-
86
- class MixinB(brainstate.mixin.Mixin):
87
- def method_b(self):
88
- return "B"
89
-
90
- class Component(MixinA, MixinB):
91
- pass
92
-
93
- comp = Component()
94
- self.assertEqual(comp.method_a(), "A")
95
- self.assertEqual(comp.method_b(), "B")
96
-
97
-
98
- class TestParamDesc(unittest.TestCase):
99
- """Test ParamDesc mixin and ParamDescriber."""
100
-
101
- def test_param_desc_basic(self):
102
- """Test basic ParamDesc functionality."""
103
-
104
- class Network(brainstate.mixin.ParamDesc):
105
- def __init__(self, size, learning_rate=0.01):
106
- self.size = size
107
- self.learning_rate = learning_rate
108
-
109
- # Test desc method exists
110
- self.assertTrue(hasattr(Network, 'desc'))
111
-
112
- # Create a descriptor
113
- desc = Network.desc(size=100)
114
- self.assertIsInstance(desc, brainstate.mixin.ParamDescriber)
115
-
116
- def test_param_describer_instantiation(self):
117
- """Test ParamDescriber can create instances."""
118
-
119
- class Network(brainstate.mixin.ParamDesc):
120
- def __init__(self, size, learning_rate=0.01):
121
- self.size = size
122
- self.learning_rate = learning_rate
123
-
124
- desc = Network.desc(size=100, learning_rate=0.001)
125
-
126
- # Create instances
127
- net1 = desc()
128
- self.assertEqual(net1.size, 100)
129
- self.assertEqual(net1.learning_rate, 0.001)
130
-
131
- # Create with overrides
132
- net2 = desc(learning_rate=0.005)
133
- self.assertEqual(net2.size, 100)
134
- self.assertEqual(net2.learning_rate, 0.005)
135
-
136
- def test_param_describer_init_method(self):
137
- """Test ParamDescriber.init() method."""
138
-
139
- class Model(brainstate.mixin.ParamDesc):
140
- def __init__(self, value):
141
- self.value = value
142
-
143
- desc = Model.desc(value=42)
144
- instance = desc.init()
145
- self.assertEqual(instance.value, 42)
146
-
147
- def test_param_describer_identifier(self):
148
- """Test ParamDescriber identifier property."""
149
-
150
- class Model(brainstate.mixin.ParamDesc):
151
- def __init__(self, x, y=10):
152
- self.x = x
153
- self.y = y
154
-
155
- desc = Model.desc(x=5, y=20)
156
- identifier = desc.identifier
157
-
158
- # Identifier should be a tuple
159
- self.assertIsInstance(identifier, tuple)
160
- self.assertEqual(len(identifier), 3)
161
- self.assertEqual(identifier[0], Model)
162
-
163
- # Identifier should be read-only
164
- with self.assertRaises(AttributeError):
165
- desc.identifier = "new"
166
-
167
- def test_param_describer_class_getitem(self):
168
- """Test ParamDescriber[Class] notation."""
169
-
170
- class Model:
171
- def __init__(self, value):
172
- self.value = value
173
-
174
- desc = brainstate.mixin.ParamDescriber[Model]
175
- self.assertIsInstance(desc, brainstate.mixin.ParamDescriber)
176
- self.assertEqual(desc.cls, Model)
177
-
178
- def test_no_subclass_meta(self):
179
- """Test that ParamDescriber cannot be subclassed."""
180
-
181
- with self.assertRaises(TypeError):
182
- class CustomDescriber(brainstate.mixin.ParamDescriber):
183
- pass
184
-
185
-
186
- class TestHashableDict(unittest.TestCase):
187
- """Test HashableDict class."""
188
-
189
- def test_hashable_dict_basic(self):
190
- """Test basic HashableDict functionality."""
191
- d = brainstate.mixin.HashableDict({"a": 1, "b": 2})
192
- h = hash(d)
193
- self.assertIsInstance(h, int)
194
-
195
- def test_hashable_dict_with_arrays(self):
196
- """Test HashableDict with non-hashable values."""
197
- d = brainstate.mixin.HashableDict({
198
- "array": jnp.array([1, 2, 3]),
199
- "value": 42
200
- })
201
- h = hash(d)
202
- self.assertIsInstance(h, int)
203
-
204
- def test_hashable_dict_consistency(self):
205
- """Test that equal dicts have equal hashes."""
206
- d1 = brainstate.mixin.HashableDict({"a": 1, "b": 2})
207
- d2 = brainstate.mixin.HashableDict({"b": 2, "a": 1})
208
- self.assertEqual(hash(d1), hash(d2))
209
-
210
- def test_hashable_dict_usable_as_key(self):
211
- """Test that HashableDict can be used as dict key."""
212
- d = brainstate.mixin.HashableDict({"x": 10})
213
- cache = {d: "result"}
214
- self.assertEqual(cache[d], "result")
215
-
216
-
217
- class TestJointTypes(unittest.TestCase):
218
- """Test JointTypes functionality."""
219
-
220
- def test_joint_types_basic(self):
221
- """Test basic JointTypes creation."""
222
-
223
- class A:
224
- pass
225
-
226
- class B:
227
- pass
228
-
229
- JointAB = brainstate.mixin.JointTypes(A, B)
230
- self.assertIsNotNone(JointAB)
231
-
232
- def test_joint_types_isinstance(self):
233
- """Test isinstance with JointTypes."""
234
-
235
- class Serializable:
236
- def save(self):
237
- pass
238
-
239
- class Visualizable:
240
- def plot(self):
241
- pass
242
-
243
- Combined = brainstate.mixin.JointTypes(Serializable, Visualizable)
244
-
245
- class Model(Serializable, Visualizable):
246
- def save(self):
247
- return "saved"
248
-
249
- def plot(self):
250
- return "plotted"
251
-
252
- model = Model()
253
- self.assertTrue(isinstance(model, Combined))
254
-
255
- def test_joint_types_issubclass(self):
256
- """Test issubclass with JointTypes."""
257
-
258
- class A:
259
- pass
260
-
261
- class B:
262
- pass
263
-
264
- JointAB = brainstate.mixin.JointTypes(A, B)
265
-
266
- class C(A, B):
267
- pass
268
-
269
- self.assertTrue(issubclass(C, JointAB))
270
-
271
- def test_joint_types_single_type(self):
272
- """Test JointTypes with single type returns that type."""
273
-
274
- class A:
275
- pass
276
-
277
- result = brainstate.mixin.JointTypes(A)
278
- self.assertEqual(result, A)
279
-
280
- def test_joint_types_no_types(self):
281
- """Test JointTypes with no types raises error."""
282
- with self.assertRaises(TypeError):
283
- brainstate.mixin.JointTypes()
284
-
285
- def test_joint_types_removes_duplicates(self):
286
- """Test that JointTypes removes duplicate types."""
287
-
288
- class A:
289
- pass
290
-
291
- # Should handle duplicates gracefully
292
- JointA = brainstate.mixin.JointTypes(A, A, A)
293
- self.assertEqual(JointA, A)
294
-
295
-
296
- class TestOneOfTypes(unittest.TestCase):
297
- """Test OneOfTypes functionality."""
298
-
299
- def test_one_of_types_basic(self):
300
- """Test basic OneOfTypes creation."""
301
- IntOrFloat = brainstate.mixin.OneOfTypes(int, float)
302
- self.assertIsNotNone(IntOrFloat)
303
-
304
- def test_one_of_types_isinstance(self):
305
- """Test isinstance with OneOfTypes."""
306
- NumType = brainstate.mixin.OneOfTypes(int, float)
307
-
308
- self.assertTrue(isinstance(42, NumType))
309
- self.assertTrue(isinstance(3.14, NumType))
310
- self.assertFalse(isinstance("hello", NumType))
311
-
312
- def test_one_of_types_single_type(self):
313
- """Test OneOfTypes with single type returns that type."""
314
- result = brainstate.mixin.OneOfTypes(int)
315
- self.assertEqual(result, int)
316
-
317
- def test_one_of_types_no_types(self):
318
- """Test OneOfTypes with no types raises error."""
319
- with self.assertRaises(TypeError):
320
- brainstate.mixin.OneOfTypes()
321
-
322
- def test_one_of_types_with_none(self):
323
- """Test OneOfTypes with None for optional types."""
324
- MaybeInt = brainstate.mixin.OneOfTypes(int, type(None))
325
-
326
- self.assertTrue(isinstance(42, MaybeInt))
327
- self.assertTrue(isinstance(None, MaybeInt))
328
- self.assertFalse(isinstance("hello", MaybeInt))
329
-
330
-
331
-
332
- class TestNotImplemented(unittest.TestCase):
333
- """Test not_implemented decorator."""
334
-
335
- def test_not_implemented_decorator(self):
336
- """Test not_implemented decorator marks functions."""
337
-
338
- @brainstate.mixin.not_implemented
339
- def my_function():
340
- pass
341
-
342
- self.assertTrue(hasattr(my_function, 'not_implemented'))
343
- self.assertTrue(my_function.not_implemented)
344
-
345
- def test_not_implemented_raises(self):
346
- """Test not_implemented decorator raises error when called."""
347
-
348
- @brainstate.mixin.not_implemented
349
- def my_function():
350
- pass
351
-
352
- with self.assertRaises(NotImplementedError) as cm:
353
- my_function()
354
-
355
- self.assertIn("my_function", str(cm.exception))
356
-
357
-
358
- class TestMode(unittest.TestCase):
359
- """Test Mode base class."""
360
-
361
- def test_mode_creation(self):
362
- """Test basic Mode creation."""
363
- mode = brainstate.mixin.Mode()
364
- self.assertIsNotNone(mode)
365
-
366
- def test_mode_repr(self):
367
- """Test Mode string representation."""
368
- mode = brainstate.mixin.Mode()
369
- self.assertEqual(repr(mode), "Mode")
370
-
371
- def test_mode_equality(self):
372
- """Test Mode equality comparison."""
373
- mode1 = brainstate.mixin.Mode()
374
- mode2 = brainstate.mixin.Mode()
375
- self.assertEqual(mode1, mode2)
376
-
377
- def test_mode_is_a(self):
378
- """Test Mode.is_a() method."""
379
- mode = brainstate.mixin.Mode()
380
- self.assertTrue(mode.is_a(brainstate.mixin.Mode))
381
- self.assertFalse(mode.is_a(brainstate.mixin.Training))
382
-
383
- def test_mode_has(self):
384
- """Test Mode.has() method."""
385
- mode = brainstate.mixin.Mode()
386
- self.assertTrue(mode.has(brainstate.mixin.Mode))
387
- self.assertFalse(mode.has(brainstate.mixin.Training))
388
-
389
- def test_custom_mode(self):
390
- """Test creating custom mode."""
391
-
392
- class CustomMode(brainstate.mixin.Mode):
393
- def __init__(self, value):
394
- self.value = value
395
-
396
- mode = CustomMode(42)
397
- self.assertEqual(mode.value, 42)
398
- self.assertTrue(mode.has(brainstate.mixin.Mode))
399
-
400
-
401
- class TestTraining(unittest.TestCase):
402
- """Test Training mode."""
403
-
404
- def test_training_creation(self):
405
- """Test Training mode creation."""
406
- training = brainstate.mixin.Training()
407
- self.assertIsNotNone(training)
408
-
409
- def test_training_is_mode(self):
410
- """Test Training is a Mode."""
411
- training = brainstate.mixin.Training()
412
- self.assertTrue(training.has(brainstate.mixin.Mode))
413
-
414
- def test_training_is_a(self):
415
- """Test Training.is_a() method."""
416
- training = brainstate.mixin.Training()
417
- self.assertTrue(training.is_a(brainstate.mixin.Training))
418
- self.assertFalse(training.is_a(brainstate.mixin.Batching))
419
-
420
- def test_training_has(self):
421
- """Test Training.has() method."""
422
- training = brainstate.mixin.Training()
423
- self.assertTrue(training.has(brainstate.mixin.Training))
424
- self.assertFalse(training.has(brainstate.mixin.Batching))
425
-
426
- def test_training_joint_types(self):
427
- """Test Training with JointTypes."""
428
- training = brainstate.mixin.Training()
429
- self.assertTrue(training.is_a(brainstate.mixin.JointTypes(brainstate.mixin.Training)))
430
- self.assertTrue(training.has(brainstate.mixin.JointTypes(brainstate.mixin.Training)))
431
-
432
-
433
- class TestBatching(unittest.TestCase):
434
- """Test Batching mode."""
435
-
436
- def test_batching_creation(self):
437
- """Test Batching mode creation."""
438
- batching = brainstate.mixin.Batching()
439
- self.assertIsNotNone(batching)
440
-
441
- def test_batching_default_params(self):
442
- """Test Batching default parameters."""
443
- batching = brainstate.mixin.Batching()
444
- self.assertEqual(batching.batch_size, 1)
445
- self.assertEqual(batching.batch_axis, 0)
446
-
447
- def test_batching_custom_params(self):
448
- """Test Batching with custom parameters."""
449
- batching = brainstate.mixin.Batching(batch_size=32, batch_axis=1)
450
- self.assertEqual(batching.batch_size, 32)
451
- self.assertEqual(batching.batch_axis, 1)
452
-
453
- def test_batching_repr(self):
454
- """Test Batching string representation."""
455
- batching = brainstate.mixin.Batching(batch_size=64, batch_axis=0)
456
- self.assertIn("64", repr(batching))
457
- self.assertIn("0", repr(batching))
458
-
459
- def test_batching_is_mode(self):
460
- """Test Batching is a Mode."""
461
- batching = brainstate.mixin.Batching()
462
- self.assertTrue(batching.has(brainstate.mixin.Mode))
463
-
464
- def test_batching_is_a(self):
465
- """Test Batching.is_a() method."""
466
- batching = brainstate.mixin.Batching()
467
- self.assertTrue(batching.is_a(brainstate.mixin.Batching))
468
- self.assertFalse(batching.is_a(brainstate.mixin.Training))
469
-
470
- def test_batching_has(self):
471
- """Test Batching.has() method."""
472
- batching = brainstate.mixin.Batching()
473
- self.assertTrue(batching.has(brainstate.mixin.Batching))
474
- self.assertFalse(batching.has(brainstate.mixin.Training))
475
-
476
-
477
- class TestJointMode(unittest.TestCase):
478
- """Test JointMode functionality."""
479
-
480
- def test_joint_mode_creation(self):
481
- """Test JointMode creation."""
482
- training = brainstate.mixin.Training()
483
- batching = brainstate.mixin.Batching()
484
- joint = brainstate.mixin.JointMode(training, batching)
485
- self.assertIsNotNone(joint)
486
-
487
- def test_joint_mode_repr(self):
488
- """Test JointMode string representation."""
489
- training = brainstate.mixin.Training()
490
- batching = brainstate.mixin.Batching(batch_size=32)
491
- joint = brainstate.mixin.JointMode(training, batching)
492
-
493
- repr_str = repr(joint)
494
- self.assertIn("JointMode", repr_str)
495
- self.assertIn("Training", repr_str)
496
- self.assertIn("Batching", repr_str)
497
-
498
- def test_joint_mode_has(self):
499
- """Test JointMode.has() method."""
500
- training = brainstate.mixin.Training()
501
- batching = brainstate.mixin.Batching()
502
- joint = brainstate.mixin.JointMode(training, batching)
503
-
504
- self.assertTrue(joint.has(brainstate.mixin.Training))
505
- self.assertTrue(joint.has(brainstate.mixin.Batching))
506
- self.assertTrue(joint.has(brainstate.mixin.Mode))
507
-
508
- def test_joint_mode_is_a(self):
509
- """Test JointMode.is_a() method."""
510
- training = brainstate.mixin.Training()
511
- batching = brainstate.mixin.Batching()
512
- joint = brainstate.mixin.JointMode(training, batching)
513
-
514
- # JointMode.is_a() works by checking if the JointTypes of the mode types
515
- # matches the expected type. This is a complex comparison.
516
- # For practical use, test that it correctly identifies single types
517
- self.assertFalse(joint.is_a(brainstate.mixin.Training)) # Not just Training
518
- self.assertFalse(joint.is_a(brainstate.mixin.Batching)) # Not just Batching
519
-
520
- # But a single mode joint should match
521
- single_joint = brainstate.mixin.JointMode(training)
522
- self.assertTrue(single_joint.is_a(brainstate.mixin.Training))
523
-
524
- def test_joint_mode_single_mode(self):
525
- """Test JointMode with single mode."""
526
- batching = brainstate.mixin.Batching()
527
- joint = brainstate.mixin.JointMode(batching)
528
-
529
- self.assertTrue(joint.has(brainstate.mixin.Batching))
530
- self.assertTrue(joint.is_a(brainstate.mixin.Batching))
531
-
532
- def test_joint_mode_attribute_access(self):
533
- """Test JointMode attribute delegation."""
534
- batching = brainstate.mixin.Batching(batch_size=32, batch_axis=1)
535
- training = brainstate.mixin.Training()
536
- joint = brainstate.mixin.JointMode(batching, training)
537
-
538
- # Should access batching attributes
539
- self.assertEqual(joint.batch_size, 32)
540
- self.assertEqual(joint.batch_axis, 1)
541
-
542
- def test_joint_mode_invalid_type(self):
543
- """Test JointMode with non-Mode raises error."""
544
- with self.assertRaises(TypeError):
545
- brainstate.mixin.JointMode("not a mode")
546
-
547
- def test_joint_mode_modes_attribute(self):
548
- """Test accessing modes attribute."""
549
- training = brainstate.mixin.Training()
550
- batching = brainstate.mixin.Batching()
551
- joint = brainstate.mixin.JointMode(training, batching)
552
-
553
- self.assertEqual(len(joint.modes), 2)
554
- self.assertIn(training, joint.modes)
555
- self.assertIn(batching, joint.modes)
556
-
557
- def test_joint_mode_types_attribute(self):
558
- """Test accessing types attribute."""
559
- training = brainstate.mixin.Training()
560
- batching = brainstate.mixin.Batching()
561
- joint = brainstate.mixin.JointMode(training, batching)
562
-
563
- self.assertEqual(len(joint.types), 2)
564
- self.assertIn(brainstate.mixin.Training, joint.types)
565
- self.assertIn(brainstate.mixin.Batching, joint.types)
566
-
567
-
568
- class TestIntegration(unittest.TestCase):
569
- """Integration tests combining multiple features."""
570
-
571
- def test_param_desc_with_modes(self):
572
- """Test ParamDesc with Mode system."""
573
-
574
- class Model(brainstate.mixin.ParamDesc):
575
- def __init__(self, size, mode=None):
576
- self.size = size
577
- self.mode = mode if mode is not None else brainstate.mixin.Mode()
578
-
579
- # Create descriptor with training mode
580
- train_model_desc = Model.desc(size=100, mode=brainstate.mixin.Training())
581
- model = train_model_desc()
582
-
583
- self.assertEqual(model.size, 100)
584
- self.assertTrue(model.mode.has(brainstate.mixin.Training))
585
-
586
- def test_joint_types_with_multiple_mixins(self):
587
- """Test JointTypes with multiple mixin classes."""
588
-
589
- class Serializable(brainstate.mixin.Mixin):
590
- def save(self):
591
- return "saved"
592
-
593
- class Trainable(brainstate.mixin.Mixin):
594
- def train(self):
595
- return "trained"
596
-
597
- class Evaluable(brainstate.mixin.Mixin):
598
- def evaluate(self):
599
- return "evaluated"
600
-
601
- FullModel = brainstate.mixin.JointTypes(Serializable, Trainable, Evaluable)
602
-
603
- class MyModel(Serializable, Trainable, Evaluable):
604
- pass
605
-
606
- model = MyModel()
607
- self.assertTrue(isinstance(model, FullModel))
608
- self.assertEqual(model.save(), "saved")
609
- self.assertEqual(model.train(), "trained")
610
- self.assertEqual(model.evaluate(), "evaluated")
611
-
612
- def test_complex_mode_scenario(self):
613
- """Test complex scenario with multiple modes."""
614
-
615
- class NeuralNetwork:
616
- def __init__(self):
617
- self.mode = None
618
-
619
- def set_mode(self, mode):
620
- self.mode = mode
621
-
622
- def forward(self, x):
623
- if self.mode is None:
624
- return x
625
-
626
- if self.mode.has(brainstate.mixin.Training):
627
- # Add noise during training
628
- x = x + 0.1
629
-
630
- if self.mode.has(brainstate.mixin.Batching):
631
- # Process in batches
632
- batch_size = self.mode.batch_size
633
- # Just return with batch info for testing
634
- return x, batch_size
635
-
636
- return x
637
-
638
- net = NeuralNetwork()
639
-
640
- # Test evaluation mode
641
- result = net.forward(1.0)
642
- self.assertEqual(result, 1.0)
643
-
644
- # Test training mode
645
- net.set_mode(brainstate.mixin.Training())
646
- result = net.forward(1.0)
647
- self.assertAlmostEqual(result, 1.1)
648
-
649
- # Test joint mode
650
- training = brainstate.mixin.Training()
651
- batching = brainstate.mixin.Batching(batch_size=32)
652
- net.set_mode(brainstate.mixin.JointMode(training, batching))
653
-
654
- result, batch_size = net.forward(1.0)
655
- self.assertAlmostEqual(result, 1.1)
656
- self.assertEqual(batch_size, 32)
657
-
658
-
659
- class TestJointTypesComprehensive(unittest.TestCase):
660
- """Comprehensive tests for JointTypes special methods and functionality."""
661
-
662
- def setUp(self):
663
- """Set up test classes."""
664
- class A:
665
- pass
666
-
667
- class B:
668
- pass
669
-
670
- class C:
671
- pass
672
-
673
- self.A = A
674
- self.B = B
675
- self.C = C
676
-
677
- def test_repr(self):
678
- """Test __repr__ method."""
679
- JT = brainstate.mixin.JointTypes[self.A, self.B]
680
- repr_str = repr(JT)
681
- self.assertIn('JointTypes', repr_str)
682
- self.assertIn('A', repr_str)
683
- self.assertIn('B', repr_str)
684
-
685
- def test_eq_same_order(self):
686
- """Test equality with same type order."""
687
- JT1 = brainstate.mixin.JointTypes[self.A, self.B]
688
- JT2 = brainstate.mixin.JointTypes[self.A, self.B]
689
- self.assertEqual(JT1, JT2)
690
-
691
- def test_eq_different_order(self):
692
- """Test equality with different type order."""
693
- JT1 = brainstate.mixin.JointTypes[self.A, self.B]
694
- JT2 = brainstate.mixin.JointTypes[self.B, self.A]
695
- self.assertEqual(JT1, JT2)
696
-
697
- def test_eq_different_types(self):
698
- """Test inequality with different types."""
699
- JT1 = brainstate.mixin.JointTypes[self.A, self.B]
700
- JT2 = brainstate.mixin.JointTypes[self.A, self.C]
701
- self.assertNotEqual(JT1, JT2)
702
-
703
- def test_eq_with_non_jointtypes(self):
704
- """Test equality with non-JointTypes object."""
705
- JT = brainstate.mixin.JointTypes[self.A, self.B]
706
- self.assertNotEqual(JT, "not a type")
707
- self.assertNotEqual(JT, 42)
708
- self.assertNotEqual(JT, self.A)
709
-
710
- def test_hash_consistency(self):
711
- """Test hash consistency."""
712
- JT = brainstate.mixin.JointTypes[self.A, self.B]
713
- hash1 = hash(JT)
714
- hash2 = hash(JT)
715
- self.assertEqual(hash1, hash2)
716
-
717
- def test_hash_order_independent(self):
718
- """Test hash is order-independent."""
719
- JT1 = brainstate.mixin.JointTypes[self.A, self.B]
720
- JT2 = brainstate.mixin.JointTypes[self.B, self.A]
721
- self.assertEqual(hash(JT1), hash(JT2))
722
-
723
- def test_hash_different_for_different_types(self):
724
- """Test different types have different hashes."""
725
- JT1 = brainstate.mixin.JointTypes[self.A, self.B]
726
- JT2 = brainstate.mixin.JointTypes[self.A, self.C]
727
- # Note: hash collision is possible but unlikely for different types
728
- self.assertNotEqual(hash(JT1), hash(JT2))
729
-
730
- def test_hashable_in_set(self):
731
- """Test JointTypes can be used in sets."""
732
- JT1 = brainstate.mixin.JointTypes[self.A, self.B]
733
- JT2 = brainstate.mixin.JointTypes[self.B, self.A]
734
- JT3 = brainstate.mixin.JointTypes[self.A, self.C]
735
-
736
- type_set = {JT1, JT2, JT3}
737
- # JT1 and JT2 are equal, so set should have 2 elements
738
- self.assertEqual(len(type_set), 2)
739
- self.assertIn(JT1, type_set)
740
- self.assertIn(JT2, type_set)
741
- self.assertIn(JT3, type_set)
742
-
743
- def test_as_dict_key(self):
744
- """Test JointTypes can be used as dict keys."""
745
- JT1 = brainstate.mixin.JointTypes[self.A, self.B]
746
- JT2 = brainstate.mixin.JointTypes[self.B, self.A]
747
-
748
- type_dict = {JT1: "AB type"}
749
- self.assertIn(JT1, type_dict)
750
- # JT2 should work as key since it's equal to JT1
751
- self.assertIn(JT2, type_dict)
752
- self.assertEqual(type_dict[JT2], "AB type")
753
-
754
- def test_pickle_roundtrip(self):
755
- """Test pickling and unpickling with built-in types."""
756
- import pickle
757
- # Use built-in types since local classes can't be pickled
758
- JT = brainstate.mixin.JointTypes[int, str]
759
- pickled = pickle.dumps(JT)
760
- unpickled = pickle.loads(pickled)
761
- self.assertEqual(JT, unpickled)
762
- self.assertEqual(hash(JT), hash(unpickled))
763
-
764
- def test_pickle_preserves_isinstance(self):
765
- """Test isinstance works after pickle with built-in types."""
766
- import pickle
767
-
768
- class IntStr(int):
769
- """A class that inherits from int."""
770
- pass
771
-
772
- # Use built-in types for pickling
773
- JT = brainstate.mixin.JointTypes[int, object]
774
- pickled = pickle.dumps(JT)
775
- unpickled = pickle.loads(pickled)
776
-
777
- obj = 42
778
- self.assertTrue(isinstance(obj, JT))
779
- self.assertTrue(isinstance(obj, unpickled))
780
-
781
- def test_multiple_types(self):
782
- """Test JointTypes with more than 2 types."""
783
- JT = brainstate.mixin.JointTypes[self.A, self.B, self.C]
784
-
785
- class ABC(self.A, self.B, self.C):
786
- pass
787
-
788
- self.assertTrue(issubclass(ABC, JT))
789
-
790
- class AB(self.A, self.B):
791
- pass
792
-
793
- self.assertFalse(issubclass(AB, JT))
794
-
795
- def test_subscript_vs_call_syntax(self):
796
- """Test subscript and call syntax produce equal results."""
797
- JT_subscript = brainstate.mixin.JointTypes[self.A, self.B]
798
- JT_call = brainstate.mixin.JointTypes(self.A, self.B)
799
- self.assertEqual(JT_subscript, JT_call)
800
- self.assertEqual(hash(JT_subscript), hash(JT_call))
801
-
802
- def test_args_attribute(self):
803
- """Test __args__ attribute contains correct types."""
804
- JT = brainstate.mixin.JointTypes[self.A, self.B]
805
- self.assertIn(self.A, JT.__args__)
806
- self.assertIn(self.B, JT.__args__)
807
- self.assertEqual(len(JT.__args__), 2)
808
-
809
-
810
- class TestOneOfTypesComprehensive(unittest.TestCase):
811
- """Comprehensive tests for OneOfTypes special methods and functionality."""
812
-
813
- def setUp(self):
814
- """Set up test classes."""
815
- class A:
816
- pass
817
-
818
- class B:
819
- pass
820
-
821
- class C:
822
- pass
823
-
824
- self.A = A
825
- self.B = B
826
- self.C = C
827
-
828
- def test_repr(self):
829
- """Test __repr__ method."""
830
- OT = brainstate.mixin.OneOfTypes[self.A, self.B]
831
- repr_str = repr(OT)
832
- self.assertIn('OneOfTypes', repr_str)
833
- self.assertIn('A', repr_str)
834
- self.assertIn('B', repr_str)
835
-
836
- def test_eq_same_order(self):
837
- """Test equality with same type order."""
838
- OT1 = brainstate.mixin.OneOfTypes[self.A, self.B]
839
- OT2 = brainstate.mixin.OneOfTypes[self.A, self.B]
840
- self.assertEqual(OT1, OT2)
841
-
842
- def test_eq_different_order(self):
843
- """Test equality with different type order."""
844
- OT1 = brainstate.mixin.OneOfTypes[self.A, self.B]
845
- OT2 = brainstate.mixin.OneOfTypes[self.B, self.A]
846
- self.assertEqual(OT1, OT2)
847
-
848
- def test_eq_different_types(self):
849
- """Test inequality with different types."""
850
- OT1 = brainstate.mixin.OneOfTypes[self.A, self.B]
851
- OT2 = brainstate.mixin.OneOfTypes[self.A, self.C]
852
- self.assertNotEqual(OT1, OT2)
853
-
854
- def test_eq_with_non_oneoftypes(self):
855
- """Test equality with non-OneOfTypes object."""
856
- OT = brainstate.mixin.OneOfTypes[self.A, self.B]
857
- self.assertNotEqual(OT, "not a type")
858
- self.assertNotEqual(OT, 42)
859
- self.assertNotEqual(OT, self.A)
860
-
861
- def test_hash_consistency(self):
862
- """Test hash consistency."""
863
- OT = brainstate.mixin.OneOfTypes[self.A, self.B]
864
- hash1 = hash(OT)
865
- hash2 = hash(OT)
866
- self.assertEqual(hash1, hash2)
867
-
868
- def test_hash_order_independent(self):
869
- """Test hash is order-independent."""
870
- OT1 = brainstate.mixin.OneOfTypes[self.A, self.B]
871
- OT2 = brainstate.mixin.OneOfTypes[self.B, self.A]
872
- self.assertEqual(hash(OT1), hash(OT2))
873
-
874
- def test_hash_different_for_different_types(self):
875
- """Test different types have different hashes."""
876
- OT1 = brainstate.mixin.OneOfTypes[self.A, self.B]
877
- OT2 = brainstate.mixin.OneOfTypes[self.A, self.C]
878
- self.assertNotEqual(hash(OT1), hash(OT2))
879
-
880
- def test_hashable_in_set(self):
881
- """Test OneOfTypes can be used in sets."""
882
- OT1 = brainstate.mixin.OneOfTypes[self.A, self.B]
883
- OT2 = brainstate.mixin.OneOfTypes[self.B, self.A]
884
- OT3 = brainstate.mixin.OneOfTypes[self.A, self.C]
885
-
886
- type_set = {OT1, OT2, OT3}
887
- # OT1 and OT2 are equal, so set should have 2 elements
888
- self.assertEqual(len(type_set), 2)
889
- self.assertIn(OT1, type_set)
890
- self.assertIn(OT2, type_set)
891
- self.assertIn(OT3, type_set)
892
-
893
- def test_as_dict_key(self):
894
- """Test OneOfTypes can be used as dict keys."""
895
- OT1 = brainstate.mixin.OneOfTypes[self.A, self.B]
896
- OT2 = brainstate.mixin.OneOfTypes[self.B, self.A]
897
-
898
- type_dict = {OT1: "A or B type"}
899
- self.assertIn(OT1, type_dict)
900
- self.assertIn(OT2, type_dict)
901
- self.assertEqual(type_dict[OT2], "A or B type")
902
-
903
- def test_pickle_roundtrip(self):
904
- """Test pickling and unpickling with built-in types."""
905
- import pickle
906
- # Use built-in types since local classes can't be pickled
907
- OT = brainstate.mixin.OneOfTypes[int, str]
908
- pickled = pickle.dumps(OT)
909
- unpickled = pickle.loads(pickled)
910
- self.assertEqual(OT, unpickled)
911
- self.assertEqual(hash(OT), hash(unpickled))
912
-
913
- def test_pickle_preserves_isinstance(self):
914
- """Test isinstance works after pickle with built-in types."""
915
- import pickle
916
- # Use built-in types for pickling
917
- OT = brainstate.mixin.OneOfTypes[int, str]
918
- pickled = pickle.dumps(OT)
919
- unpickled = pickle.loads(pickled)
920
-
921
- obj_a = 42
922
- obj_b = "hello"
923
-
924
- self.assertTrue(isinstance(obj_a, OT))
925
- self.assertTrue(isinstance(obj_a, unpickled))
926
- self.assertTrue(isinstance(obj_b, OT))
927
- self.assertTrue(isinstance(obj_b, unpickled))
928
-
929
- def test_isinstance_with_any_type(self):
930
- """Test isinstance returns True if object is instance of any type."""
931
- OT = brainstate.mixin.OneOfTypes[self.A, self.B]
932
-
933
- obj_a = self.A()
934
- obj_b = self.B()
935
- obj_c = self.C()
936
-
937
- self.assertTrue(isinstance(obj_a, OT))
938
- self.assertTrue(isinstance(obj_b, OT))
939
- self.assertFalse(isinstance(obj_c, OT))
940
-
941
- def test_issubclass_with_any_type(self):
942
- """Test issubclass returns True if class is subclass of any type."""
943
- OT = brainstate.mixin.OneOfTypes[self.A, self.B]
944
-
945
- class SubA(self.A):
946
- pass
947
-
948
- class SubB(self.B):
949
- pass
950
-
951
- self.assertTrue(issubclass(SubA, OT))
952
- self.assertTrue(issubclass(SubB, OT))
953
- self.assertTrue(issubclass(self.A, OT))
954
- self.assertTrue(issubclass(self.B, OT))
955
- self.assertFalse(issubclass(self.C, OT))
956
-
957
- def test_multiple_types(self):
958
- """Test OneOfTypes with more than 2 types."""
959
- OT = brainstate.mixin.OneOfTypes[self.A, self.B, self.C]
960
-
961
- obj_a = self.A()
962
- obj_b = self.B()
963
- obj_c = self.C()
964
-
965
- self.assertTrue(isinstance(obj_a, OT))
966
- self.assertTrue(isinstance(obj_b, OT))
967
- self.assertTrue(isinstance(obj_c, OT))
968
-
969
- def test_subscript_vs_call_syntax(self):
970
- """Test subscript and call syntax produce equal results."""
971
- OT_subscript = brainstate.mixin.OneOfTypes[self.A, self.B]
972
- OT_call = brainstate.mixin.OneOfTypes(self.A, self.B)
973
- self.assertEqual(OT_subscript, OT_call)
974
- self.assertEqual(hash(OT_subscript), hash(OT_call))
975
-
976
- def test_args_attribute(self):
977
- """Test __args__ attribute contains correct types."""
978
- OT = brainstate.mixin.OneOfTypes[self.A, self.B]
979
- self.assertIn(self.A, OT.__args__)
980
- self.assertIn(self.B, OT.__args__)
981
- self.assertEqual(len(OT.__args__), 2)
982
-
983
- def test_with_builtin_types(self):
984
- """Test OneOfTypes with built-in types."""
985
- OT = brainstate.mixin.OneOfTypes[int, float, str]
986
-
987
- self.assertTrue(isinstance(42, OT))
988
- self.assertTrue(isinstance(3.14, OT))
989
- self.assertTrue(isinstance("hello", OT))
990
- self.assertFalse(isinstance([], OT))
991
-
992
-
993
- class TestJointTy:
994
- def test1(self):
995
- class Potassium:
996
- pass
997
-
998
- class Calcium:
999
- pass
1000
-
1001
- # Test JointTypes
1002
- result1 = brainstate.mixin.JointTypes(Potassium, Calcium)
1003
- result2 = brainstate.mixin.JointTypes[Potassium, Calcium]
1004
- print(f'Function call: {result1}')
1005
- print(f'Subscript: {result2}')
1006
- print(f'Same? {result1 == result2}')
1007
-
1008
- # Test OneOfTypes
1009
- result3 = brainstate.mixin.OneOfTypes(Potassium, Calcium)
1010
- result4 = brainstate.mixin.OneOfTypes[Potassium, Calcium]
1011
- print(f'\nOneOfTypes Function call: {result3}')
1012
- print(f'OneOfTypes Subscript: {result4}')
1013
- print(f'Same? {result3 == result4}')
1014
-
1015
-
1016
- if __name__ == '__main__':
1017
- unittest.main()
1
+ # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """
17
+ Comprehensive tests for brainstate.mixin module.
18
+
19
+ This test suite covers all functionality in the mixin module including:
20
+ - Base mixin classes
21
+ - Parameter description and deferred instantiation
22
+ - Type utilities (JointTypes, OneOfTypes)
23
+ - Mode system (Mode, JointMode, Training, Batching)
24
+ - Helper utilities (hashable, not_implemented, etc.)
25
+ """
26
+
27
+ import unittest
28
+
29
+ import jax.numpy as jnp
30
+
31
+ import brainstate
32
+
33
+
34
+ class TestHashableFunction(unittest.TestCase):
35
+ """Test the hashable utility function."""
36
+
37
+ def test_hashable_primitives(self):
38
+ """Test hashable with primitive types."""
39
+ self.assertTrue(brainstate.mixin.hashable(42))
40
+ self.assertTrue(brainstate.mixin.hashable(3.14))
41
+ self.assertTrue(brainstate.mixin.hashable("string"))
42
+ self.assertTrue(brainstate.mixin.hashable(True))
43
+ self.assertTrue(brainstate.mixin.hashable(None))
44
+
45
+ def test_hashable_tuples(self):
46
+ """Test hashable with tuples."""
47
+ self.assertTrue(brainstate.mixin.hashable((1, 2, 3)))
48
+ self.assertTrue(brainstate.mixin.hashable(("a", "b")))
49
+ self.assertTrue(brainstate.mixin.hashable(()))
50
+
51
+ def test_non_hashable_types(self):
52
+ """Test non-hashable types."""
53
+ self.assertFalse(brainstate.mixin.hashable([1, 2, 3]))
54
+ self.assertFalse(brainstate.mixin.hashable({"key": "value"}))
55
+ self.assertFalse(brainstate.mixin.hashable({1, 2, 3}))
56
+ self.assertFalse(brainstate.mixin.hashable(jnp.array([1, 2, 3])))
57
+
58
+
59
+ class TestMixin(unittest.TestCase):
60
+ """Test the base Mixin class."""
61
+
62
+ def test_mixin_exists(self):
63
+ """Test that Mixin class exists."""
64
+ self.assertTrue(brainstate.mixin.Mixin)
65
+
66
+ def test_mixin_inheritance(self):
67
+ """Test creating a custom mixin."""
68
+
69
+ class LoggingMixin(brainstate.mixin.Mixin):
70
+ def log(self, message):
71
+ return f"[LOG] {message}"
72
+
73
+ class Component(LoggingMixin):
74
+ pass
75
+
76
+ comp = Component()
77
+ self.assertEqual(comp.log("test"), "[LOG] test")
78
+
79
+ def test_mixin_multiple_inheritance(self):
80
+ """Test multiple mixin inheritance."""
81
+
82
+ class MixinA(brainstate.mixin.Mixin):
83
+ def method_a(self):
84
+ return "A"
85
+
86
+ class MixinB(brainstate.mixin.Mixin):
87
+ def method_b(self):
88
+ return "B"
89
+
90
+ class Component(MixinA, MixinB):
91
+ pass
92
+
93
+ comp = Component()
94
+ self.assertEqual(comp.method_a(), "A")
95
+ self.assertEqual(comp.method_b(), "B")
96
+
97
+
98
+ class TestParamDesc(unittest.TestCase):
99
+ """Test ParamDesc mixin and ParamDescriber."""
100
+
101
+ def test_param_desc_basic(self):
102
+ """Test basic ParamDesc functionality."""
103
+
104
+ class Network(brainstate.mixin.ParamDesc):
105
+ def __init__(self, size, learning_rate=0.01):
106
+ self.size = size
107
+ self.learning_rate = learning_rate
108
+
109
+ # Test desc method exists
110
+ self.assertTrue(hasattr(Network, 'desc'))
111
+
112
+ # Create a descriptor
113
+ desc = Network.desc(size=100)
114
+ self.assertIsInstance(desc, brainstate.mixin.ParamDescriber)
115
+
116
+ def test_param_describer_instantiation(self):
117
+ """Test ParamDescriber can create instances."""
118
+
119
+ class Network(brainstate.mixin.ParamDesc):
120
+ def __init__(self, size, learning_rate=0.01):
121
+ self.size = size
122
+ self.learning_rate = learning_rate
123
+
124
+ desc = Network.desc(size=100, learning_rate=0.001)
125
+
126
+ # Create instances
127
+ net1 = desc()
128
+ self.assertEqual(net1.size, 100)
129
+ self.assertEqual(net1.learning_rate, 0.001)
130
+
131
+ # Create with overrides
132
+ net2 = desc(learning_rate=0.005)
133
+ self.assertEqual(net2.size, 100)
134
+ self.assertEqual(net2.learning_rate, 0.005)
135
+
136
+ def test_param_describer_init_method(self):
137
+ """Test ParamDescriber.init() method."""
138
+
139
+ class Model(brainstate.mixin.ParamDesc):
140
+ def __init__(self, value):
141
+ self.value = value
142
+
143
+ desc = Model.desc(value=42)
144
+ instance = desc.init()
145
+ self.assertEqual(instance.value, 42)
146
+
147
+ def test_param_describer_identifier(self):
148
+ """Test ParamDescriber identifier property."""
149
+
150
+ class Model(brainstate.mixin.ParamDesc):
151
+ def __init__(self, x, y=10):
152
+ self.x = x
153
+ self.y = y
154
+
155
+ desc = Model.desc(x=5, y=20)
156
+ identifier = desc.identifier
157
+
158
+ # Identifier should be a tuple
159
+ self.assertIsInstance(identifier, tuple)
160
+ self.assertEqual(len(identifier), 3)
161
+ self.assertEqual(identifier[0], Model)
162
+
163
+ # Identifier should be read-only
164
+ with self.assertRaises(AttributeError):
165
+ desc.identifier = "new"
166
+
167
+ def test_param_describer_class_getitem(self):
168
+ """Test ParamDescriber[Class] notation."""
169
+
170
+ class Model:
171
+ def __init__(self, value):
172
+ self.value = value
173
+
174
+ desc = brainstate.mixin.ParamDescriber[Model]
175
+ self.assertIsInstance(desc, brainstate.mixin.ParamDescriber)
176
+ self.assertEqual(desc.cls, Model)
177
+
178
+ def test_no_subclass_meta(self):
179
+ """Test that ParamDescriber cannot be subclassed."""
180
+
181
+ with self.assertRaises(TypeError):
182
+ class CustomDescriber(brainstate.mixin.ParamDescriber):
183
+ pass
184
+
185
+
186
+ class TestHashableDict(unittest.TestCase):
187
+ """Test HashableDict class."""
188
+
189
+ def test_hashable_dict_basic(self):
190
+ """Test basic HashableDict functionality."""
191
+ d = brainstate.mixin.HashableDict({"a": 1, "b": 2})
192
+ h = hash(d)
193
+ self.assertIsInstance(h, int)
194
+
195
+ def test_hashable_dict_with_arrays(self):
196
+ """Test HashableDict with non-hashable values."""
197
+ d = brainstate.mixin.HashableDict({
198
+ "array": jnp.array([1, 2, 3]),
199
+ "value": 42
200
+ })
201
+ h = hash(d)
202
+ self.assertIsInstance(h, int)
203
+
204
+ def test_hashable_dict_consistency(self):
205
+ """Test that equal dicts have equal hashes."""
206
+ d1 = brainstate.mixin.HashableDict({"a": 1, "b": 2})
207
+ d2 = brainstate.mixin.HashableDict({"b": 2, "a": 1})
208
+ self.assertEqual(hash(d1), hash(d2))
209
+
210
+ def test_hashable_dict_usable_as_key(self):
211
+ """Test that HashableDict can be used as dict key."""
212
+ d = brainstate.mixin.HashableDict({"x": 10})
213
+ cache = {d: "result"}
214
+ self.assertEqual(cache[d], "result")
215
+
216
+
217
+ class TestJointTypes(unittest.TestCase):
218
+ """Test JointTypes functionality."""
219
+
220
+ def test_joint_types_basic(self):
221
+ """Test basic JointTypes creation."""
222
+
223
+ class A:
224
+ pass
225
+
226
+ class B:
227
+ pass
228
+
229
+ JointAB = brainstate.mixin.JointTypes(A, B)
230
+ self.assertIsNotNone(JointAB)
231
+
232
+ def test_joint_types_isinstance(self):
233
+ """Test isinstance with JointTypes."""
234
+
235
+ class Serializable:
236
+ def save(self):
237
+ pass
238
+
239
+ class Visualizable:
240
+ def plot(self):
241
+ pass
242
+
243
+ Combined = brainstate.mixin.JointTypes(Serializable, Visualizable)
244
+
245
+ class Model(Serializable, Visualizable):
246
+ def save(self):
247
+ return "saved"
248
+
249
+ def plot(self):
250
+ return "plotted"
251
+
252
+ model = Model()
253
+ self.assertTrue(isinstance(model, Combined))
254
+
255
+ def test_joint_types_issubclass(self):
256
+ """Test issubclass with JointTypes."""
257
+
258
+ class A:
259
+ pass
260
+
261
+ class B:
262
+ pass
263
+
264
+ JointAB = brainstate.mixin.JointTypes(A, B)
265
+
266
+ class C(A, B):
267
+ pass
268
+
269
+ self.assertTrue(issubclass(C, JointAB))
270
+
271
+ def test_joint_types_single_type(self):
272
+ """Test JointTypes with single type returns that type."""
273
+
274
+ class A:
275
+ pass
276
+
277
+ result = brainstate.mixin.JointTypes(A)
278
+ self.assertEqual(result, A)
279
+
280
+ def test_joint_types_no_types(self):
281
+ """Test JointTypes with no types raises error."""
282
+ with self.assertRaises(TypeError):
283
+ brainstate.mixin.JointTypes()
284
+
285
+ def test_joint_types_removes_duplicates(self):
286
+ """Test that JointTypes removes duplicate types."""
287
+
288
+ class A:
289
+ pass
290
+
291
+ # Should handle duplicates gracefully
292
+ JointA = brainstate.mixin.JointTypes(A, A, A)
293
+ self.assertEqual(JointA, A)
294
+
295
+
296
+ class TestOneOfTypes(unittest.TestCase):
297
+ """Test OneOfTypes functionality."""
298
+
299
+ def test_one_of_types_basic(self):
300
+ """Test basic OneOfTypes creation."""
301
+ IntOrFloat = brainstate.mixin.OneOfTypes(int, float)
302
+ self.assertIsNotNone(IntOrFloat)
303
+
304
+ def test_one_of_types_isinstance(self):
305
+ """Test isinstance with OneOfTypes."""
306
+ NumType = brainstate.mixin.OneOfTypes(int, float)
307
+
308
+ self.assertTrue(isinstance(42, NumType))
309
+ self.assertTrue(isinstance(3.14, NumType))
310
+ self.assertFalse(isinstance("hello", NumType))
311
+
312
+ def test_one_of_types_single_type(self):
313
+ """Test OneOfTypes with single type returns that type."""
314
+ result = brainstate.mixin.OneOfTypes(int)
315
+ self.assertEqual(result, int)
316
+
317
+ def test_one_of_types_no_types(self):
318
+ """Test OneOfTypes with no types raises error."""
319
+ with self.assertRaises(TypeError):
320
+ brainstate.mixin.OneOfTypes()
321
+
322
+ def test_one_of_types_with_none(self):
323
+ """Test OneOfTypes with None for optional types."""
324
+ MaybeInt = brainstate.mixin.OneOfTypes(int, type(None))
325
+
326
+ self.assertTrue(isinstance(42, MaybeInt))
327
+ self.assertTrue(isinstance(None, MaybeInt))
328
+ self.assertFalse(isinstance("hello", MaybeInt))
329
+
330
+
331
+
332
+ class TestNotImplemented(unittest.TestCase):
333
+ """Test not_implemented decorator."""
334
+
335
+ def test_not_implemented_decorator(self):
336
+ """Test not_implemented decorator marks functions."""
337
+
338
+ @brainstate.mixin.not_implemented
339
+ def my_function():
340
+ pass
341
+
342
+ self.assertTrue(hasattr(my_function, 'not_implemented'))
343
+ self.assertTrue(my_function.not_implemented)
344
+
345
+ def test_not_implemented_raises(self):
346
+ """Test not_implemented decorator raises error when called."""
347
+
348
+ @brainstate.mixin.not_implemented
349
+ def my_function():
350
+ pass
351
+
352
+ with self.assertRaises(NotImplementedError) as cm:
353
+ my_function()
354
+
355
+ self.assertIn("my_function", str(cm.exception))
356
+
357
+
358
+ class TestMode(unittest.TestCase):
359
+ """Test Mode base class."""
360
+
361
+ def test_mode_creation(self):
362
+ """Test basic Mode creation."""
363
+ mode = brainstate.mixin.Mode()
364
+ self.assertIsNotNone(mode)
365
+
366
+ def test_mode_repr(self):
367
+ """Test Mode string representation."""
368
+ mode = brainstate.mixin.Mode()
369
+ self.assertEqual(repr(mode), "Mode")
370
+
371
+ def test_mode_equality(self):
372
+ """Test Mode equality comparison."""
373
+ mode1 = brainstate.mixin.Mode()
374
+ mode2 = brainstate.mixin.Mode()
375
+ self.assertEqual(mode1, mode2)
376
+
377
+ def test_mode_is_a(self):
378
+ """Test Mode.is_a() method."""
379
+ mode = brainstate.mixin.Mode()
380
+ self.assertTrue(mode.is_a(brainstate.mixin.Mode))
381
+ self.assertFalse(mode.is_a(brainstate.mixin.Training))
382
+
383
+ def test_mode_has(self):
384
+ """Test Mode.has() method."""
385
+ mode = brainstate.mixin.Mode()
386
+ self.assertTrue(mode.has(brainstate.mixin.Mode))
387
+ self.assertFalse(mode.has(brainstate.mixin.Training))
388
+
389
+ def test_custom_mode(self):
390
+ """Test creating custom mode."""
391
+
392
+ class CustomMode(brainstate.mixin.Mode):
393
+ def __init__(self, value):
394
+ self.value = value
395
+
396
+ mode = CustomMode(42)
397
+ self.assertEqual(mode.value, 42)
398
+ self.assertTrue(mode.has(brainstate.mixin.Mode))
399
+
400
+
401
+ class TestTraining(unittest.TestCase):
402
+ """Test Training mode."""
403
+
404
+ def test_training_creation(self):
405
+ """Test Training mode creation."""
406
+ training = brainstate.mixin.Training()
407
+ self.assertIsNotNone(training)
408
+
409
+ def test_training_is_mode(self):
410
+ """Test Training is a Mode."""
411
+ training = brainstate.mixin.Training()
412
+ self.assertTrue(training.has(brainstate.mixin.Mode))
413
+
414
+ def test_training_is_a(self):
415
+ """Test Training.is_a() method."""
416
+ training = brainstate.mixin.Training()
417
+ self.assertTrue(training.is_a(brainstate.mixin.Training))
418
+ self.assertFalse(training.is_a(brainstate.mixin.Batching))
419
+
420
+ def test_training_has(self):
421
+ """Test Training.has() method."""
422
+ training = brainstate.mixin.Training()
423
+ self.assertTrue(training.has(brainstate.mixin.Training))
424
+ self.assertFalse(training.has(brainstate.mixin.Batching))
425
+
426
+ def test_training_joint_types(self):
427
+ """Test Training with JointTypes."""
428
+ training = brainstate.mixin.Training()
429
+ self.assertTrue(training.is_a(brainstate.mixin.JointTypes(brainstate.mixin.Training)))
430
+ self.assertTrue(training.has(brainstate.mixin.JointTypes(brainstate.mixin.Training)))
431
+
432
+
433
+ class TestBatching(unittest.TestCase):
434
+ """Test Batching mode."""
435
+
436
+ def test_batching_creation(self):
437
+ """Test Batching mode creation."""
438
+ batching = brainstate.mixin.Batching()
439
+ self.assertIsNotNone(batching)
440
+
441
+ def test_batching_default_params(self):
442
+ """Test Batching default parameters."""
443
+ batching = brainstate.mixin.Batching()
444
+ self.assertEqual(batching.batch_size, 1)
445
+ self.assertEqual(batching.batch_axis, 0)
446
+
447
+ def test_batching_custom_params(self):
448
+ """Test Batching with custom parameters."""
449
+ batching = brainstate.mixin.Batching(batch_size=32, batch_axis=1)
450
+ self.assertEqual(batching.batch_size, 32)
451
+ self.assertEqual(batching.batch_axis, 1)
452
+
453
+ def test_batching_repr(self):
454
+ """Test Batching string representation."""
455
+ batching = brainstate.mixin.Batching(batch_size=64, batch_axis=0)
456
+ self.assertIn("64", repr(batching))
457
+ self.assertIn("0", repr(batching))
458
+
459
+ def test_batching_is_mode(self):
460
+ """Test Batching is a Mode."""
461
+ batching = brainstate.mixin.Batching()
462
+ self.assertTrue(batching.has(brainstate.mixin.Mode))
463
+
464
+ def test_batching_is_a(self):
465
+ """Test Batching.is_a() method."""
466
+ batching = brainstate.mixin.Batching()
467
+ self.assertTrue(batching.is_a(brainstate.mixin.Batching))
468
+ self.assertFalse(batching.is_a(brainstate.mixin.Training))
469
+
470
+ def test_batching_has(self):
471
+ """Test Batching.has() method."""
472
+ batching = brainstate.mixin.Batching()
473
+ self.assertTrue(batching.has(brainstate.mixin.Batching))
474
+ self.assertFalse(batching.has(brainstate.mixin.Training))
475
+
476
+
477
+ class TestJointMode(unittest.TestCase):
478
+ """Test JointMode functionality."""
479
+
480
+ def test_joint_mode_creation(self):
481
+ """Test JointMode creation."""
482
+ training = brainstate.mixin.Training()
483
+ batching = brainstate.mixin.Batching()
484
+ joint = brainstate.mixin.JointMode(training, batching)
485
+ self.assertIsNotNone(joint)
486
+
487
+ def test_joint_mode_repr(self):
488
+ """Test JointMode string representation."""
489
+ training = brainstate.mixin.Training()
490
+ batching = brainstate.mixin.Batching(batch_size=32)
491
+ joint = brainstate.mixin.JointMode(training, batching)
492
+
493
+ repr_str = repr(joint)
494
+ self.assertIn("JointMode", repr_str)
495
+ self.assertIn("Training", repr_str)
496
+ self.assertIn("Batching", repr_str)
497
+
498
+ def test_joint_mode_has(self):
499
+ """Test JointMode.has() method."""
500
+ training = brainstate.mixin.Training()
501
+ batching = brainstate.mixin.Batching()
502
+ joint = brainstate.mixin.JointMode(training, batching)
503
+
504
+ self.assertTrue(joint.has(brainstate.mixin.Training))
505
+ self.assertTrue(joint.has(brainstate.mixin.Batching))
506
+ self.assertTrue(joint.has(brainstate.mixin.Mode))
507
+
508
+ def test_joint_mode_is_a(self):
509
+ """Test JointMode.is_a() method."""
510
+ training = brainstate.mixin.Training()
511
+ batching = brainstate.mixin.Batching()
512
+ joint = brainstate.mixin.JointMode(training, batching)
513
+
514
+ # JointMode.is_a() works by checking if the JointTypes of the mode types
515
+ # matches the expected type. This is a complex comparison.
516
+ # For practical use, test that it correctly identifies single types
517
+ self.assertFalse(joint.is_a(brainstate.mixin.Training)) # Not just Training
518
+ self.assertFalse(joint.is_a(brainstate.mixin.Batching)) # Not just Batching
519
+
520
+ # But a single mode joint should match
521
+ single_joint = brainstate.mixin.JointMode(training)
522
+ self.assertTrue(single_joint.is_a(brainstate.mixin.Training))
523
+
524
+ def test_joint_mode_single_mode(self):
525
+ """Test JointMode with single mode."""
526
+ batching = brainstate.mixin.Batching()
527
+ joint = brainstate.mixin.JointMode(batching)
528
+
529
+ self.assertTrue(joint.has(brainstate.mixin.Batching))
530
+ self.assertTrue(joint.is_a(brainstate.mixin.Batching))
531
+
532
+ def test_joint_mode_attribute_access(self):
533
+ """Test JointMode attribute delegation."""
534
+ batching = brainstate.mixin.Batching(batch_size=32, batch_axis=1)
535
+ training = brainstate.mixin.Training()
536
+ joint = brainstate.mixin.JointMode(batching, training)
537
+
538
+ # Should access batching attributes
539
+ self.assertEqual(joint.batch_size, 32)
540
+ self.assertEqual(joint.batch_axis, 1)
541
+
542
+ def test_joint_mode_invalid_type(self):
543
+ """Test JointMode with non-Mode raises error."""
544
+ with self.assertRaises(TypeError):
545
+ brainstate.mixin.JointMode("not a mode")
546
+
547
+ def test_joint_mode_modes_attribute(self):
548
+ """Test accessing modes attribute."""
549
+ training = brainstate.mixin.Training()
550
+ batching = brainstate.mixin.Batching()
551
+ joint = brainstate.mixin.JointMode(training, batching)
552
+
553
+ self.assertEqual(len(joint.modes), 2)
554
+ self.assertIn(training, joint.modes)
555
+ self.assertIn(batching, joint.modes)
556
+
557
+ def test_joint_mode_types_attribute(self):
558
+ """Test accessing types attribute."""
559
+ training = brainstate.mixin.Training()
560
+ batching = brainstate.mixin.Batching()
561
+ joint = brainstate.mixin.JointMode(training, batching)
562
+
563
+ self.assertEqual(len(joint.types), 2)
564
+ self.assertIn(brainstate.mixin.Training, joint.types)
565
+ self.assertIn(brainstate.mixin.Batching, joint.types)
566
+
567
+
568
+ class TestIntegration(unittest.TestCase):
569
+ """Integration tests combining multiple features."""
570
+
571
+ def test_param_desc_with_modes(self):
572
+ """Test ParamDesc with Mode system."""
573
+
574
+ class Model(brainstate.mixin.ParamDesc):
575
+ def __init__(self, size, mode=None):
576
+ self.size = size
577
+ self.mode = mode if mode is not None else brainstate.mixin.Mode()
578
+
579
+ # Create descriptor with training mode
580
+ train_model_desc = Model.desc(size=100, mode=brainstate.mixin.Training())
581
+ model = train_model_desc()
582
+
583
+ self.assertEqual(model.size, 100)
584
+ self.assertTrue(model.mode.has(brainstate.mixin.Training))
585
+
586
+ def test_joint_types_with_multiple_mixins(self):
587
+ """Test JointTypes with multiple mixin classes."""
588
+
589
+ class Serializable(brainstate.mixin.Mixin):
590
+ def save(self):
591
+ return "saved"
592
+
593
+ class Trainable(brainstate.mixin.Mixin):
594
+ def train(self):
595
+ return "trained"
596
+
597
+ class Evaluable(brainstate.mixin.Mixin):
598
+ def evaluate(self):
599
+ return "evaluated"
600
+
601
+ FullModel = brainstate.mixin.JointTypes(Serializable, Trainable, Evaluable)
602
+
603
+ class MyModel(Serializable, Trainable, Evaluable):
604
+ pass
605
+
606
+ model = MyModel()
607
+ self.assertTrue(isinstance(model, FullModel))
608
+ self.assertEqual(model.save(), "saved")
609
+ self.assertEqual(model.train(), "trained")
610
+ self.assertEqual(model.evaluate(), "evaluated")
611
+
612
+ def test_complex_mode_scenario(self):
613
+ """Test complex scenario with multiple modes."""
614
+
615
+ class NeuralNetwork:
616
+ def __init__(self):
617
+ self.mode = None
618
+
619
+ def set_mode(self, mode):
620
+ self.mode = mode
621
+
622
+ def forward(self, x):
623
+ if self.mode is None:
624
+ return x
625
+
626
+ if self.mode.has(brainstate.mixin.Training):
627
+ # Add noise during training
628
+ x = x + 0.1
629
+
630
+ if self.mode.has(brainstate.mixin.Batching):
631
+ # Process in batches
632
+ batch_size = self.mode.batch_size
633
+ # Just return with batch info for testing
634
+ return x, batch_size
635
+
636
+ return x
637
+
638
+ net = NeuralNetwork()
639
+
640
+ # Test evaluation mode
641
+ result = net.forward(1.0)
642
+ self.assertEqual(result, 1.0)
643
+
644
+ # Test training mode
645
+ net.set_mode(brainstate.mixin.Training())
646
+ result = net.forward(1.0)
647
+ self.assertAlmostEqual(result, 1.1)
648
+
649
+ # Test joint mode
650
+ training = brainstate.mixin.Training()
651
+ batching = brainstate.mixin.Batching(batch_size=32)
652
+ net.set_mode(brainstate.mixin.JointMode(training, batching))
653
+
654
+ result, batch_size = net.forward(1.0)
655
+ self.assertAlmostEqual(result, 1.1)
656
+ self.assertEqual(batch_size, 32)
657
+
658
+
659
class TestJointTypesComprehensive(unittest.TestCase):
    """Comprehensive tests for JointTypes special methods and functionality.

    Covers ``__repr__``, order-insensitive ``__eq__``/``__hash__``, use as a
    set member and dict key, pickling, ``issubclass`` with more than two
    types, subscript vs. call construction, and the ``__args__`` attribute.
    """

    def setUp(self):
        """Create three unrelated local classes used as type arguments."""

        class A:
            pass

        class B:
            pass

        class C:
            pass

        self.A = A
        self.B = B
        self.C = C

    def test_repr(self):
        """repr() names the alias and each member type."""
        JT = brainstate.mixin.JointTypes[self.A, self.B]
        repr_str = repr(JT)
        self.assertIn('JointTypes', repr_str)
        self.assertIn('A', repr_str)
        self.assertIn('B', repr_str)

    def test_eq_same_order(self):
        """Aliases over the same types in the same order are equal."""
        JT1 = brainstate.mixin.JointTypes[self.A, self.B]
        JT2 = brainstate.mixin.JointTypes[self.A, self.B]
        self.assertEqual(JT1, JT2)

    def test_eq_different_order(self):
        """Equality ignores the order of the type arguments."""
        JT1 = brainstate.mixin.JointTypes[self.A, self.B]
        JT2 = brainstate.mixin.JointTypes[self.B, self.A]
        self.assertEqual(JT1, JT2)

    def test_eq_different_types(self):
        """Aliases over different type sets are not equal."""
        JT1 = brainstate.mixin.JointTypes[self.A, self.B]
        JT2 = brainstate.mixin.JointTypes[self.A, self.C]
        self.assertNotEqual(JT1, JT2)

    def test_eq_with_non_jointtypes(self):
        """A JointTypes alias never equals an unrelated object or a plain class."""
        JT = brainstate.mixin.JointTypes[self.A, self.B]
        self.assertNotEqual(JT, "not a type")
        self.assertNotEqual(JT, 42)
        self.assertNotEqual(JT, self.A)

    def test_hash_consistency(self):
        """Hashing the same alias twice gives the same value."""
        JT = brainstate.mixin.JointTypes[self.A, self.B]
        hash1 = hash(JT)
        hash2 = hash(JT)
        self.assertEqual(hash1, hash2)

    def test_hash_order_independent(self):
        """The hash ignores the order of the type arguments (matches __eq__)."""
        JT1 = brainstate.mixin.JointTypes[self.A, self.B]
        JT2 = brainstate.mixin.JointTypes[self.B, self.A]
        self.assertEqual(hash(JT1), hash(JT2))

    def test_hash_different_for_different_types(self):
        """Aliases over different type sets hash differently."""
        JT1 = brainstate.mixin.JointTypes[self.A, self.B]
        JT2 = brainstate.mixin.JointTypes[self.A, self.C]
        # Note: a hash collision is permitted by the hash contract but is
        # extremely unlikely for distinct class objects.
        self.assertNotEqual(hash(JT1), hash(JT2))

    def test_hashable_in_set(self):
        """Equal aliases collapse to a single set member."""
        JT1 = brainstate.mixin.JointTypes[self.A, self.B]
        JT2 = brainstate.mixin.JointTypes[self.B, self.A]
        JT3 = brainstate.mixin.JointTypes[self.A, self.C]

        type_set = {JT1, JT2, JT3}
        # JT1 and JT2 are equal, so the set should have 2 elements.
        self.assertEqual(len(type_set), 2)
        self.assertIn(JT1, type_set)
        self.assertIn(JT2, type_set)
        self.assertIn(JT3, type_set)

    def test_as_dict_key(self):
        """An equal alias in a different order finds the same dict entry."""
        JT1 = brainstate.mixin.JointTypes[self.A, self.B]
        JT2 = brainstate.mixin.JointTypes[self.B, self.A]

        type_dict = {JT1: "AB type"}
        self.assertIn(JT1, type_dict)
        # JT2 works as a key because it is equal to JT1.
        self.assertIn(JT2, type_dict)
        self.assertEqual(type_dict[JT2], "AB type")

    def test_pickle_roundtrip(self):
        """A pickled and unpickled alias is equal to, and hashes like, the original."""
        import pickle

        # Use built-in types: classes defined locally in setUp can't be pickled.
        JT = brainstate.mixin.JointTypes[int, str]
        pickled = pickle.dumps(JT)
        unpickled = pickle.loads(pickled)
        self.assertEqual(JT, unpickled)
        self.assertEqual(hash(JT), hash(unpickled))

    def test_pickle_preserves_isinstance(self):
        """isinstance behaves the same before and after a pickle round trip.

        Built-in types are used because local classes can't be pickled.
        ``42`` is both an ``int`` and an ``object``, so it must match.
        ``"hello"`` is an ``object`` but not an ``int``, so it must not match.
        Without the negative case, an unpickled alias that lost ``int`` would
        go unnoticed.
        """
        import pickle

        JT = brainstate.mixin.JointTypes[int, object]
        pickled = pickle.dumps(JT)
        unpickled = pickle.loads(pickled)

        obj = 42
        self.assertTrue(isinstance(obj, JT))
        self.assertTrue(isinstance(obj, unpickled))

        # Satisfies only one of the joint types, so it must be rejected.
        partial = "hello"
        self.assertFalse(isinstance(partial, JT))
        self.assertFalse(isinstance(partial, unpickled))

    def test_multiple_types(self):
        """With three types, issubclass requires inheriting from all of them."""
        JT = brainstate.mixin.JointTypes[self.A, self.B, self.C]

        class ABC(self.A, self.B, self.C):
            pass

        self.assertTrue(issubclass(ABC, JT))

        class AB(self.A, self.B):
            pass

        # Missing C, so not a subclass of the joint type.
        self.assertFalse(issubclass(AB, JT))

    def test_subscript_vs_call_syntax(self):
        """Subscript and call syntax produce equal aliases with equal hashes."""
        JT_subscript = brainstate.mixin.JointTypes[self.A, self.B]
        JT_call = brainstate.mixin.JointTypes(self.A, self.B)
        self.assertEqual(JT_subscript, JT_call)
        self.assertEqual(hash(JT_subscript), hash(JT_call))

    def test_args_attribute(self):
        """__args__ holds exactly the member types."""
        JT = brainstate.mixin.JointTypes[self.A, self.B]
        self.assertIn(self.A, JT.__args__)
        self.assertIn(self.B, JT.__args__)
        self.assertEqual(len(JT.__args__), 2)
808
+
809
+
810
class TestOneOfTypesComprehensive(unittest.TestCase):
    """Exercise OneOfTypes end to end.

    Checks repr, equality and hashing that ignore argument order, use in
    sets and dicts, pickling, "any of" isinstance/issubclass semantics,
    subscript vs. call construction, and the ``__args__`` attribute.
    """

    def setUp(self):
        """Build three unrelated throwaway classes to use as type arguments."""

        class A:
            pass

        class B:
            pass

        class C:
            pass

        self.A, self.B, self.C = A, B, C

    def _one_of(self, *types):
        """Return ``OneOfTypes[...]`` built with subscript syntax from ``types``."""
        # ``X[a, b]`` passes the tuple ``(a, b)`` to __getitem__, exactly as here.
        return brainstate.mixin.OneOfTypes[types]

    def test_repr(self):
        """repr() mentions the alias name and every member type."""
        text = repr(self._one_of(self.A, self.B))
        for fragment in ('OneOfTypes', 'A', 'B'):
            self.assertIn(fragment, text)

    def test_eq_same_order(self):
        """Identical argument lists give equal aliases."""
        first = self._one_of(self.A, self.B)
        second = self._one_of(self.A, self.B)
        self.assertEqual(first, second)

    def test_eq_different_order(self):
        """Swapping the arguments does not affect equality."""
        forward = self._one_of(self.A, self.B)
        backward = self._one_of(self.B, self.A)
        self.assertEqual(forward, backward)

    def test_eq_different_types(self):
        """Different member types give unequal aliases."""
        ab = self._one_of(self.A, self.B)
        ac = self._one_of(self.A, self.C)
        self.assertNotEqual(ab, ac)

    def test_eq_with_non_oneoftypes(self):
        """An alias never equals a string, an int, or one of its member classes."""
        alias = self._one_of(self.A, self.B)
        for other in ("not a type", 42, self.A):
            self.assertNotEqual(alias, other)

    def test_hash_consistency(self):
        """Repeated hashing of one alias is stable."""
        alias = self._one_of(self.A, self.B)
        first, second = hash(alias), hash(alias)
        self.assertEqual(first, second)

    def test_hash_order_independent(self):
        """Swapping the arguments does not affect the hash."""
        forward = self._one_of(self.A, self.B)
        backward = self._one_of(self.B, self.A)
        self.assertEqual(hash(forward), hash(backward))

    def test_hash_different_for_different_types(self):
        """Different member types hash differently."""
        ab = self._one_of(self.A, self.B)
        ac = self._one_of(self.A, self.C)
        self.assertNotEqual(hash(ab), hash(ac))

    def test_hashable_in_set(self):
        """Order-swapped duplicates collapse inside a set."""
        ab = self._one_of(self.A, self.B)
        ba = self._one_of(self.B, self.A)
        ac = self._one_of(self.A, self.C)

        members = {ab, ba, ac}
        # ab == ba, so only two distinct entries survive.
        self.assertEqual(len(members), 2)
        for alias in (ab, ba, ac):
            self.assertIn(alias, members)

    def test_as_dict_key(self):
        """An order-swapped alias retrieves the same dict value."""
        ab = self._one_of(self.A, self.B)
        ba = self._one_of(self.B, self.A)

        lookup = {ab: "A or B type"}
        self.assertIn(ab, lookup)
        self.assertIn(ba, lookup)
        self.assertEqual(lookup[ba], "A or B type")

    def test_pickle_roundtrip(self):
        """A pickle round trip preserves equality and hash."""
        import pickle
        # Local classes are not picklable, so builtins are used here.
        alias = self._one_of(int, str)
        restored = pickle.loads(pickle.dumps(alias))
        self.assertEqual(alias, restored)
        self.assertEqual(hash(alias), hash(restored))

    def test_pickle_preserves_isinstance(self):
        """Values of either member type match before and after pickling."""
        import pickle
        alias = self._one_of(int, str)
        restored = pickle.loads(pickle.dumps(alias))

        for value in (42, "hello"):
            self.assertTrue(isinstance(value, alias))
            self.assertTrue(isinstance(value, restored))

    def test_isinstance_with_any_type(self):
        """An instance of any member type matches; others do not."""
        alias = self._one_of(self.A, self.B)

        a_obj, b_obj, c_obj = self.A(), self.B(), self.C()

        self.assertTrue(isinstance(a_obj, alias))
        self.assertTrue(isinstance(b_obj, alias))
        self.assertFalse(isinstance(c_obj, alias))

    def test_issubclass_with_any_type(self):
        """Member types and their subclasses pass issubclass; others do not."""
        alias = self._one_of(self.A, self.B)

        class SubA(self.A):
            pass

        class SubB(self.B):
            pass

        for candidate in (SubA, SubB, self.A, self.B):
            self.assertTrue(issubclass(candidate, alias))
        self.assertFalse(issubclass(self.C, alias))

    def test_multiple_types(self):
        """With three member types, an instance of each one matches."""
        alias = self._one_of(self.A, self.B, self.C)

        instances = [factory() for factory in (self.A, self.B, self.C)]

        for instance in instances:
            self.assertTrue(isinstance(instance, alias))

    def test_subscript_vs_call_syntax(self):
        """Subscript and call construction yield equal aliases with equal hashes."""
        via_subscript = brainstate.mixin.OneOfTypes[self.A, self.B]
        via_call = brainstate.mixin.OneOfTypes(self.A, self.B)
        self.assertEqual(via_subscript, via_call)
        self.assertEqual(hash(via_subscript), hash(via_call))

    def test_args_attribute(self):
        """__args__ contains exactly the two member types."""
        alias = self._one_of(self.A, self.B)
        self.assertIn(self.A, alias.__args__)
        self.assertIn(self.B, alias.__args__)
        self.assertEqual(len(alias.__args__), 2)

    def test_with_builtin_types(self):
        """Built-in members match their own values and reject a list."""
        alias = self._one_of(int, float, str)

        for value in (42, 3.14, "hello"):
            self.assertTrue(isinstance(value, alias))
        self.assertFalse(isinstance([], alias))
991
+
992
+
993
class TestJointTy(unittest.TestCase):
    """Smoke test: call and subscript syntax build equal JointTypes/OneOfTypes.

    This is a ``unittest.TestCase`` so that ``unittest.main()`` discovers it.
    It asserts the relationships between the two construction forms.
    """

    def test1(self):
        """Call and subscript forms agree for both JointTypes and OneOfTypes."""

        class Potassium:
            pass

        class Calcium:
            pass

        # JointTypes: the function-call form and the subscript form must be
        # interchangeable, including when used as hash keys.
        result1 = brainstate.mixin.JointTypes(Potassium, Calcium)
        result2 = brainstate.mixin.JointTypes[Potassium, Calcium]
        self.assertEqual(result1, result2)
        self.assertEqual(hash(result1), hash(result2))

        # OneOfTypes: the same expectation holds.
        result3 = brainstate.mixin.OneOfTypes(Potassium, Calcium)
        result4 = brainstate.mixin.OneOfTypes[Potassium, Calcium]
        self.assertEqual(result3, result4)
        self.assertEqual(hash(result3), hash(result4))
1014
+
1015
+
1016
if __name__ == '__main__':
    # Executed as a script: hand control to unittest's command-line runner,
    # which collects every TestCase subclass defined in this module.
    unittest.main()