brainstate 0.2.1__py2.py3-none-any.whl → 0.2.2__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (115) hide show
  1. brainstate/__init__.py +167 -169
  2. brainstate/_compatible_import.py +340 -340
  3. brainstate/_compatible_import_test.py +681 -681
  4. brainstate/_deprecation.py +210 -210
  5. brainstate/_deprecation_test.py +2297 -2319
  6. brainstate/_error.py +45 -45
  7. brainstate/_state.py +2157 -1652
  8. brainstate/_state_test.py +1129 -52
  9. brainstate/_utils.py +47 -47
  10. brainstate/environ.py +1495 -1495
  11. brainstate/environ_test.py +1223 -1223
  12. brainstate/graph/__init__.py +22 -22
  13. brainstate/graph/_node.py +240 -240
  14. brainstate/graph/_node_test.py +589 -589
  15. brainstate/graph/_operation.py +1620 -1624
  16. brainstate/graph/_operation_test.py +1147 -1147
  17. brainstate/mixin.py +1447 -1433
  18. brainstate/mixin_test.py +1017 -1017
  19. brainstate/nn/__init__.py +146 -137
  20. brainstate/nn/_activations.py +1100 -1100
  21. brainstate/nn/_activations_test.py +354 -354
  22. brainstate/nn/_collective_ops.py +635 -633
  23. brainstate/nn/_collective_ops_test.py +774 -774
  24. brainstate/nn/_common.py +226 -226
  25. brainstate/nn/_common_test.py +134 -154
  26. brainstate/nn/_conv.py +2010 -2010
  27. brainstate/nn/_conv_test.py +849 -849
  28. brainstate/nn/_delay.py +575 -575
  29. brainstate/nn/_delay_test.py +243 -243
  30. brainstate/nn/_dropout.py +618 -618
  31. brainstate/nn/_dropout_test.py +480 -477
  32. brainstate/nn/_dynamics.py +870 -1267
  33. brainstate/nn/_dynamics_test.py +53 -67
  34. brainstate/nn/_elementwise.py +1298 -1298
  35. brainstate/nn/_elementwise_test.py +829 -829
  36. brainstate/nn/_embedding.py +408 -408
  37. brainstate/nn/_embedding_test.py +156 -156
  38. brainstate/nn/_event_fixedprob.py +233 -233
  39. brainstate/nn/_event_fixedprob_test.py +115 -115
  40. brainstate/nn/_event_linear.py +83 -83
  41. brainstate/nn/_event_linear_test.py +121 -121
  42. brainstate/nn/_exp_euler.py +254 -254
  43. brainstate/nn/_exp_euler_test.py +377 -377
  44. brainstate/nn/_linear.py +744 -744
  45. brainstate/nn/_linear_test.py +475 -475
  46. brainstate/nn/_metrics.py +1070 -1070
  47. brainstate/nn/_metrics_test.py +611 -611
  48. brainstate/nn/_module.py +391 -384
  49. brainstate/nn/_module_test.py +40 -40
  50. brainstate/nn/_normalizations.py +1334 -1334
  51. brainstate/nn/_normalizations_test.py +699 -699
  52. brainstate/nn/_paddings.py +1020 -1020
  53. brainstate/nn/_paddings_test.py +722 -722
  54. brainstate/nn/_poolings.py +2239 -2239
  55. brainstate/nn/_poolings_test.py +952 -952
  56. brainstate/nn/_rnns.py +946 -946
  57. brainstate/nn/_rnns_test.py +592 -592
  58. brainstate/nn/_utils.py +216 -216
  59. brainstate/nn/_utils_test.py +401 -401
  60. brainstate/nn/init.py +809 -809
  61. brainstate/nn/init_test.py +180 -180
  62. brainstate/random/__init__.py +270 -270
  63. brainstate/random/{_rand_funs.py → _fun.py} +3938 -3938
  64. brainstate/random/{_rand_funs_test.py → _fun_test.py} +638 -640
  65. brainstate/random/_impl.py +672 -0
  66. brainstate/random/{_rand_seed.py → _seed.py} +675 -675
  67. brainstate/random/{_rand_seed_test.py → _seed_test.py} +48 -48
  68. brainstate/random/{_rand_state.py → _state.py} +1320 -1617
  69. brainstate/random/{_rand_state_test.py → _state_test.py} +551 -551
  70. brainstate/transform/__init__.py +56 -59
  71. brainstate/transform/_ad_checkpoint.py +176 -176
  72. brainstate/transform/_ad_checkpoint_test.py +49 -49
  73. brainstate/transform/_autograd.py +1025 -1025
  74. brainstate/transform/_autograd_test.py +1289 -1289
  75. brainstate/transform/_conditions.py +316 -316
  76. brainstate/transform/_conditions_test.py +220 -220
  77. brainstate/transform/_error_if.py +94 -94
  78. brainstate/transform/_error_if_test.py +52 -52
  79. brainstate/transform/_find_state.py +200 -0
  80. brainstate/transform/_find_state_test.py +84 -0
  81. brainstate/transform/_jit.py +399 -399
  82. brainstate/transform/_jit_test.py +143 -143
  83. brainstate/transform/_loop_collect_return.py +675 -675
  84. brainstate/transform/_loop_collect_return_test.py +58 -58
  85. brainstate/transform/_loop_no_collection.py +283 -283
  86. brainstate/transform/_loop_no_collection_test.py +50 -50
  87. brainstate/transform/_make_jaxpr.py +2176 -2016
  88. brainstate/transform/_make_jaxpr_test.py +1634 -1510
  89. brainstate/transform/_mapping.py +607 -529
  90. brainstate/transform/_mapping_test.py +104 -194
  91. brainstate/transform/_progress_bar.py +255 -255
  92. brainstate/transform/_unvmap.py +256 -256
  93. brainstate/transform/_util.py +286 -286
  94. brainstate/typing.py +837 -837
  95. brainstate/typing_test.py +780 -780
  96. brainstate/util/__init__.py +27 -27
  97. brainstate/util/_others.py +1024 -1024
  98. brainstate/util/_others_test.py +962 -962
  99. brainstate/util/_pretty_pytree.py +1301 -1301
  100. brainstate/util/_pretty_pytree_test.py +675 -675
  101. brainstate/util/_pretty_repr.py +462 -462
  102. brainstate/util/_pretty_repr_test.py +696 -696
  103. brainstate/util/filter.py +945 -945
  104. brainstate/util/filter_test.py +911 -911
  105. brainstate/util/struct.py +910 -910
  106. brainstate/util/struct_test.py +602 -602
  107. {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/METADATA +108 -108
  108. brainstate-0.2.2.dist-info/RECORD +111 -0
  109. {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/licenses/LICENSE +202 -202
  110. brainstate/transform/_eval_shape.py +0 -145
  111. brainstate/transform/_eval_shape_test.py +0 -38
  112. brainstate/transform/_random.py +0 -171
  113. brainstate-0.2.1.dist-info/RECORD +0 -111
  114. {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/WHEEL +0 -0
  115. {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/top_level.txt +0 -0
brainstate/typing_test.py CHANGED
@@ -1,780 +1,780 @@
1
- # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
- """
17
- Comprehensive tests for brainstate.typing module.
18
-
19
- This test suite validates all type annotations, protocols, and type utilities
20
- provided by the typing module, ensuring they work correctly with JAX, NumPy,
21
- and BrainUnit integration.
22
- """
23
-
24
- import unittest
25
- from typing import get_type_hints, Union, Any
26
-
27
- import brainunit as u
28
- import jax
29
- import jax.numpy as jnp
30
- import jax.random as jr
31
- import numpy as np
32
-
33
- from brainstate.typing import (
34
- # Key and path types
35
- Key, PathParts, FilterLiteral, Filter,
36
-
37
- # Array types
38
- Array, ArrayLike, Shape, Size, Axes, DType, DTypeLike, SupportsDType,
39
-
40
- # PyTree types
41
- PyTree,
42
-
43
- # Random types
44
- SeedOrKey,
45
-
46
- # Utility types
47
- Missing,
48
-
49
- # Type variables
50
- K, _T, _Annotation,
51
-
52
- # Internal utilities for testing
53
- _item_to_str, _maybe_tuple_to_str, _Array
54
- )
55
-
56
-
57
- class TestKeyProtocol(unittest.TestCase):
58
- """Test the Key protocol and related path types."""
59
-
60
- def setUp(self):
61
- """Set up test fixtures."""
62
- self.string_key = "layer1"
63
- self.int_key = 42
64
- self.float_key = 3.14
65
-
66
- def test_key_protocol_string(self):
67
- """Test that strings implement the Key protocol."""
68
- self.assertIsInstance(self.string_key, Key)
69
- self.assertTrue(hasattr(self.string_key, '__hash__'))
70
- self.assertTrue(hasattr(self.string_key, '__lt__'))
71
-
72
- def test_key_protocol_int(self):
73
- """Test that integers implement the Key protocol."""
74
- self.assertIsInstance(self.int_key, Key)
75
- self.assertTrue(hasattr(self.int_key, '__hash__'))
76
- self.assertTrue(hasattr(self.int_key, '__lt__'))
77
-
78
- def test_key_ordering(self):
79
- """Test that keys can be ordered."""
80
- self.assertTrue("a" < "b")
81
- self.assertTrue(1 < 2)
82
- self.assertTrue(1.0 < 2.0)
83
-
84
- def test_custom_key_class(self):
85
- """Test custom class implementing Key protocol."""
86
-
87
- class CustomKey:
88
- def __init__(self, name: str):
89
- self.name = name
90
-
91
- def __hash__(self) -> int:
92
- return hash(self.name)
93
-
94
- def __eq__(self, other) -> bool:
95
- return isinstance(other, CustomKey) and self.name == other.name
96
-
97
- def __lt__(self, other) -> bool:
98
- return isinstance(other, CustomKey) and self.name < other.name
99
-
100
- key1 = CustomKey("first")
101
- key2 = CustomKey("second")
102
-
103
- self.assertIsInstance(key1, Key)
104
- self.assertTrue(key1 < key2)
105
- self.assertEqual(hash(key1), hash(CustomKey("first")))
106
-
107
- def test_path_parts(self):
108
- """Test PathParts type usage."""
109
- # Simple path
110
- path1: PathParts = ("model", "layers", 0, "weights")
111
- self.assertEqual(len(path1), 4)
112
- self.assertIsInstance(path1[0], str)
113
- self.assertIsInstance(path1[2], int)
114
-
115
- # Empty path
116
- path2: PathParts = ()
117
- self.assertEqual(len(path2), 0)
118
-
119
- # Mixed types path
120
- path3: PathParts = ("root", 1, "sub", 2.5)
121
- self.assertEqual(len(path3), 4)
122
-
123
- def test_predicate_functions(self):
124
- """Test Predicate function type."""
125
-
126
- def is_weight_matrix(path: PathParts, value: Any) -> bool:
127
- return len(path) > 0 and "weight" in str(path[-1]) and hasattr(value, 'ndim') and value.ndim == 2
128
-
129
- def is_bias_vector(path: PathParts, value: Any) -> bool:
130
- return len(path) > 0 and "bias" in str(path[-1]) and hasattr(value, 'ndim') and value.ndim == 1
131
-
132
- # Test with mock data
133
- weight_path: PathParts = ("layer", "weight")
134
- bias_path: PathParts = ("layer", "bias")
135
-
136
- weight_matrix = np.random.randn(10, 5)
137
- bias_vector = np.random.randn(5)
138
-
139
- self.assertTrue(is_weight_matrix(weight_path, weight_matrix))
140
- self.assertFalse(is_weight_matrix(bias_path, bias_vector))
141
- self.assertTrue(is_bias_vector(bias_path, bias_vector))
142
- self.assertFalse(is_bias_vector(weight_path, weight_matrix))
143
-
144
- def test_filter_types(self):
145
- """Test various filter type combinations."""
146
- # FilterLiteral types
147
- type_filter: FilterLiteral = float
148
- string_filter: FilterLiteral = "weight"
149
- predicate_filter: FilterLiteral = lambda path, x: hasattr(x, 'ndim')
150
- bool_filter: FilterLiteral = True
151
- ellipsis_filter: FilterLiteral = ...
152
- none_filter: FilterLiteral = None
153
-
154
- # Combined filters
155
- tuple_filter: Filter = (float, "weight")
156
- list_filter: Filter = [int, float, "bias"]
157
- nested_filter: Filter = [
158
- ("weight", lambda p, x: hasattr(x, 'ndim') and x.ndim == 2),
159
- ("bias", lambda p, x: hasattr(x, 'ndim') and x.ndim == 1),
160
- ]
161
-
162
- # Verify types are correctly assigned
163
- self.assertIsInstance(type_filter, type)
164
- self.assertIsInstance(string_filter, str)
165
- self.assertTrue(callable(predicate_filter))
166
- self.assertIsInstance(bool_filter, bool)
167
- self.assertEqual(ellipsis_filter, ...)
168
- self.assertIsNone(none_filter)
169
-
170
-
171
- class TestArrayAnnotations(unittest.TestCase):
172
- """Test Array type annotations and related utilities."""
173
-
174
- def test_array_basic_annotation(self):
175
- """Test basic Array type annotation."""
176
-
177
- def process_array(x: Array) -> Array:
178
- return x * 2
179
-
180
- # Check that function can be called with various array types
181
- jax_array = jnp.array([1, 2, 3])
182
- numpy_array = np.array([1, 2, 3])
183
-
184
- result1 = process_array(jax_array)
185
- result2 = process_array(numpy_array)
186
-
187
- self.assertIsInstance(result1, jax.Array)
188
- self.assertIsInstance(result2, np.ndarray)
189
-
190
- def test_array_shape_annotation(self):
191
- """Test Array with shape annotations."""
192
-
193
- def matrix_multiply(a: Array["m, n"], b: Array["n, k"]) -> Array["m, k"]:
194
- return a @ b
195
-
196
- # Test with compatible shapes
197
- key = jax.random.PRNGKey(42)
198
- key1, key2 = jax.random.split(key)
199
- a = jax.random.normal(key1, (3, 4))
200
- b = jax.random.normal(key2, (4, 5))
201
- result = matrix_multiply(a, b)
202
-
203
- self.assertEqual(result.shape, (3, 5))
204
-
205
- def test_array_class_getitem(self):
206
- """Test Array.__class_getitem__ functionality."""
207
- # Test shape annotation creation
208
- shaped_array = Array["batch, features"]
209
- self.assertIsNotNone(shaped_array)
210
- self.assertTrue(hasattr(shaped_array, '__origin__'))
211
-
212
- # Test complex shape annotation
213
- complex_array = Array["batch, seq_len, d_model"]
214
- self.assertIsNotNone(complex_array)
215
-
216
- # Test with ellipsis
217
- flexible_array = Array["batch, ..."]
218
- self.assertIsNotNone(flexible_array)
219
-
220
- def test_item_to_str_function(self):
221
- """Test _item_to_str utility function."""
222
- # String item
223
- self.assertEqual(_item_to_str("batch"), "'batch'")
224
-
225
- # Type item
226
- self.assertEqual(_item_to_str(float), "float")
227
-
228
- # Ellipsis item
229
- self.assertEqual(_item_to_str(...), "...")
230
-
231
- # Slice item
232
- slice_item = slice("start", "stop")
233
- expected = "'start': 'stop'"
234
- self.assertEqual(_item_to_str(slice_item), expected)
235
-
236
- # Slice with step should raise NotImplementedError
237
- with self.assertRaises(NotImplementedError):
238
- _item_to_str(slice("start", "stop", "step"))
239
-
240
- def test_maybe_tuple_to_str_function(self):
241
- """Test _maybe_tuple_to_str utility function."""
242
- # Single item
243
- self.assertEqual(_maybe_tuple_to_str("single"), "'single'")
244
-
245
- # Empty tuple
246
- self.assertEqual(_maybe_tuple_to_str(()), "()")
247
-
248
- # Non-empty tuple
249
- tuple_item = ("batch", "features")
250
- expected = "'batch', 'features'"
251
- self.assertEqual(_maybe_tuple_to_str(tuple_item), expected)
252
-
253
- def test_array_module_setting(self):
254
- """Test that Array has correct module for display."""
255
- self.assertEqual(Array.__module__, "builtins")
256
-
257
-
258
- class TestShapeAndSizeTypes(unittest.TestCase):
259
- """Test shape, size, and axes type annotations."""
260
-
261
- def test_size_type_variants(self):
262
- """Test different Size type variants."""
263
- # Single integer
264
- size1: Size = 10
265
- self.assertIsInstance(size1, int)
266
-
267
- # Tuple of integers
268
- size2: Size = (3, 4, 5)
269
- self.assertIsInstance(size2, tuple)
270
- self.assertTrue(all(isinstance(x, int) for x in size2))
271
-
272
- # NumPy integers
273
- size3: Size = np.int32(8)
274
- self.assertIsInstance(size3, np.integer)
275
-
276
- # Sequence with mixed NumPy types
277
- size4: Size = [np.int64(2), 3, np.int32(4)]
278
- self.assertIsInstance(size4, list)
279
-
280
- def test_shape_type(self):
281
- """Test Shape type usage."""
282
- # 2D shape
283
- matrix_shape: Shape = (10, 20)
284
- self.assertEqual(len(matrix_shape), 2)
285
- self.assertTrue(all(isinstance(x, int) for x in matrix_shape))
286
-
287
- # 3D shape
288
- tensor_shape: Shape = (5, 10, 15)
289
- self.assertEqual(len(tensor_shape), 3)
290
-
291
- # 1D shape (still a sequence)
292
- vector_shape: Shape = (100,)
293
- self.assertEqual(len(vector_shape), 1)
294
-
295
- def test_axes_type(self):
296
- """Test Axes type variants."""
297
- # Single axis
298
- axis1: Axes = 0
299
- self.assertIsInstance(axis1, int)
300
-
301
- # Multiple axes
302
- axis2: Axes = (0, 2)
303
- self.assertIsInstance(axis2, tuple)
304
- self.assertTrue(all(isinstance(x, int) for x in axis2))
305
-
306
- # List of axes
307
- axis3: Axes = [1, 3, 4]
308
- self.assertIsInstance(axis3, list)
309
-
310
- def test_shape_operations(self):
311
- """Test operations using shape types."""
312
-
313
- def create_zeros(shape: Shape) -> jax.Array:
314
- return jnp.zeros(shape)
315
-
316
- def sum_along_axes(array: ArrayLike, axes: Axes) -> jax.Array:
317
- return jnp.sum(array, axis=axes)
318
-
319
- # Test shape creation
320
- arr = create_zeros((3, 4))
321
- self.assertEqual(arr.shape, (3, 4))
322
-
323
- # Test axes operations
324
- test_array = jnp.ones((2, 3, 4))
325
- result1 = sum_along_axes(test_array, 0)
326
- self.assertEqual(result1.shape, (3, 4))
327
-
328
- result2 = sum_along_axes(test_array, (0, 2))
329
- self.assertEqual(result2.shape, (3,))
330
-
331
-
332
- class TestArrayLikeAndDType(unittest.TestCase):
333
- """Test ArrayLike and dtype-related types."""
334
-
335
- def test_arraylike_variants(self):
336
- """Test different ArrayLike type variants."""
337
-
338
- def process_data(data: ArrayLike) -> jax.Array:
339
- return jnp.asarray(data)
340
-
341
- # JAX array
342
- jax_array = jnp.array([1, 2, 3])
343
- result1 = process_data(jax_array)
344
- self.assertIsInstance(result1, jax.Array)
345
-
346
- # NumPy array
347
- numpy_array = np.array([1, 2, 3])
348
- result2 = process_data(numpy_array)
349
- self.assertIsInstance(result2, jax.Array)
350
-
351
- # Python scalars
352
- result3 = process_data(42)
353
- self.assertIsInstance(result3, jax.Array)
354
- self.assertEqual(result3.shape, ())
355
-
356
- result4 = process_data(3.14)
357
- self.assertIsInstance(result4, jax.Array)
358
-
359
- result5 = process_data(True)
360
- self.assertIsInstance(result5, jax.Array)
361
-
362
- result6 = process_data(1 + 2j)
363
- self.assertIsInstance(result6, jax.Array)
364
-
365
- # NumPy scalars
366
- result7 = process_data(np.float32(2.5))
367
- self.assertIsInstance(result7, jax.Array)
368
-
369
- result8 = process_data(np.bool_(False))
370
- self.assertIsInstance(result8, jax.Array)
371
-
372
- # BrainUnit quantities (if available)
373
- try:
374
- quantity = 1.5 * u.second
375
- # Convert to plain array for processing
376
- result9 = process_data(quantity.mantissa)
377
- self.assertIsInstance(result9, jax.Array)
378
- except (AttributeError, TypeError):
379
- # Skip if BrainUnit quantities not properly set up
380
- pass
381
-
382
- def test_dtype_variants(self):
383
- """Test DType and DTypeLike variants."""
384
-
385
- def cast_array(array: ArrayLike, dtype: DTypeLike) -> jax.Array:
386
- return jnp.asarray(array, dtype=dtype)
387
-
388
- test_data = [1, 2, 3]
389
-
390
- # String dtype
391
- result1 = cast_array(test_data, 'float32')
392
- self.assertEqual(result1.dtype, jnp.float32)
393
-
394
- # NumPy type
395
- result2 = cast_array(test_data, np.float32)
396
- self.assertEqual(result2.dtype, jnp.float32)
397
-
398
- # Python type
399
- result3 = cast_array(test_data, float)
400
- self.assertTrue(jnp.issubdtype(result3.dtype, jnp.floating))
401
-
402
- # NumPy dtype object
403
- result4 = cast_array(test_data, np.dtype('int32'))
404
- self.assertEqual(result4.dtype, jnp.int32)
405
-
406
- def test_supports_dtype_protocol(self):
407
- """Test SupportsDType protocol."""
408
-
409
- def get_dtype(obj: SupportsDType) -> DType:
410
- return obj.dtype
411
-
412
- # Test with arrays
413
- arr = jnp.array([1.0, 2.0])
414
- dtype = get_dtype(arr)
415
- self.assertIsInstance(dtype, np.dtype)
416
-
417
- # Test with NumPy arrays
418
- np_arr = np.array([1, 2], dtype=np.int64)
419
- dtype2 = get_dtype(np_arr)
420
- self.assertEqual(dtype2, np.int64)
421
-
422
- def test_dtype_alias(self):
423
- """Test DType alias."""
424
-
425
- def create_array(shape: Shape, dtype: DType) -> jax.Array:
426
- return jnp.zeros(shape, dtype=dtype)
427
-
428
- arr = create_array((3, 4), np.float32)
429
- self.assertEqual(arr.shape, (3, 4))
430
- self.assertEqual(arr.dtype, jnp.float32)
431
-
432
-
433
- class TestPyTreeTypes(unittest.TestCase):
434
- """Test PyTree type annotations."""
435
-
436
- def test_pytree_basic_usage(self):
437
- """Test basic PyTree type usage."""
438
-
439
- def tree_function(tree: PyTree[float]) -> PyTree[float]:
440
- return jax.tree_util.tree_map(lambda x: x * 2, tree)
441
-
442
- # Test with different PyTree structures
443
- tree1 = {"a": 1.0, "b": 2.0}
444
- result1 = tree_function(tree1)
445
- self.assertAlmostEqual(result1["a"], 2.0)
446
- self.assertAlmostEqual(result1["b"], 4.0)
447
-
448
- tree2 = [1.0, 2.0, 3.0]
449
- result2 = tree_function(tree2)
450
- expected = [2.0, 4.0, 6.0]
451
- for i, (actual, expect) in enumerate(zip(result2, expected)):
452
- self.assertAlmostEqual(actual, expect)
453
-
454
- def test_pytree_with_structure(self):
455
- """Test PyTree with structure annotations."""
456
-
457
- def structured_function(tree: PyTree[float, "T"]) -> PyTree[float, "T"]:
458
- return jax.tree_util.tree_map(lambda x: x + 1, tree)
459
-
460
- # Test that function works with various structures
461
- tree = {"weights": 1.0, "bias": 2.0}
462
- result = structured_function(tree)
463
- self.assertAlmostEqual(result["weights"], 2.0)
464
- self.assertAlmostEqual(result["bias"], 3.0)
465
-
466
- def test_pytree_instantiation_error(self):
467
- """Test that PyTree cannot be instantiated."""
468
- with self.assertRaises(RuntimeError):
469
- PyTree()
470
-
471
- def test_pytree_subscripting(self):
472
- """Test PyTree subscripting behavior."""
473
- # Single type parameter
474
- pytree_type = PyTree[float]
475
- self.assertTrue(hasattr(pytree_type, 'leaftype'))
476
- self.assertEqual(pytree_type.leaftype, float)
477
-
478
- # Type and structure parameters
479
- pytree_structured = PyTree[int, "T"]
480
- self.assertTrue(hasattr(pytree_structured, 'leaftype'))
481
- self.assertTrue(hasattr(pytree_structured, 'structure'))
482
- self.assertEqual(pytree_structured.leaftype, int)
483
- self.assertEqual(pytree_structured.structure, "T")
484
-
485
- def test_pytree_structure_validation(self):
486
- """Test PyTree structure validation."""
487
- # Valid structure names
488
- valid_structures = ["T", "S T", "... T", "T ...", "foo bar"]
489
- for structure in valid_structures:
490
- PyTree[float, structure]
491
-
492
- # Invalid structures
493
- with self.assertRaises(ValueError):
494
- PyTree[float, ""] # Empty string
495
-
496
- with self.assertRaises(ValueError):
497
- PyTree[float, "invalid-identifier"] # Invalid identifier
498
-
499
- with self.assertRaises(ValueError):
500
- PyTree[float, "123abc"] # Starts with number
501
-
502
- def test_pytree_tuple_length_validation(self):
503
- """Test PyTree tuple parameter validation."""
504
- # Valid 2-tuple
505
- PyTree[float, "T"]
506
-
507
- # Invalid tuple lengths
508
- with self.assertRaises(ValueError):
509
- PyTree[float, "T", "extra"] # 3-tuple
510
-
511
- with self.assertRaises(ValueError):
512
- PyTree[float,] # 1-tuple with trailing comma would be (float,)
513
-
514
-
515
- class TestRandomTypes(unittest.TestCase):
516
- """Test random number generation types."""
517
-
518
- def test_seed_or_key_variants(self):
519
- """Test SeedOrKey type variants."""
520
-
521
- def generate_random(key: SeedOrKey, shape: Shape) -> jax.Array:
522
- if isinstance(key, int):
523
- key = jr.PRNGKey(key)
524
- return jr.normal(key, shape)
525
-
526
- # Integer seed
527
- result1 = generate_random(42, (3, 4))
528
- self.assertEqual(result1.shape, (3, 4))
529
-
530
- # JAX PRNG key
531
- jax_key = jr.PRNGKey(123)
532
- result2 = generate_random(jax_key, (5,))
533
- self.assertEqual(result2.shape, (5,))
534
-
535
- # NumPy array key
536
- np_key = np.array([1, 2], dtype=np.uint32)
537
- result3 = generate_random(np_key, (2, 2))
538
- self.assertEqual(result3.shape, (2, 2))
539
-
540
- def test_reproducibility_with_seeds(self):
541
- """Test that same seeds produce same results."""
542
-
543
- def generate_data(seed: SeedOrKey) -> jax.Array:
544
- if isinstance(seed, int):
545
- key = jr.PRNGKey(seed)
546
- else:
547
- key = seed
548
- return jr.normal(key, (5,))
549
-
550
- # Same integer seeds
551
- result1 = generate_data(42)
552
- result2 = generate_data(42)
553
- np.testing.assert_array_equal(result1, result2)
554
-
555
- # Same JAX keys
556
- key = jr.PRNGKey(999)
557
- result3 = generate_data(key)
558
- result4 = generate_data(key)
559
- np.testing.assert_array_equal(result3, result4)
560
-
561
-
562
- class TestUtilityTypes(unittest.TestCase):
563
- """Test utility types and edge cases."""
564
-
565
- def test_missing_sentinel(self):
566
- """Test Missing sentinel class."""
567
- _MISSING = Missing()
568
-
569
- def function_with_optional_param(value: Union[int, None, Missing] = _MISSING):
570
- if value is _MISSING:
571
- return "no_value"
572
- elif value is None:
573
- return "explicit_none"
574
- else:
575
- return f"value_{value}"
576
-
577
- # Test different call patterns
578
- self.assertEqual(function_with_optional_param(), "no_value")
579
- self.assertEqual(function_with_optional_param(None), "explicit_none")
580
- self.assertEqual(function_with_optional_param(42), "value_42")
581
-
582
- # Test that different Missing instances are distinct objects
583
- missing1 = Missing()
584
- missing2 = Missing()
585
- self.assertIsNot(missing1, missing2) # Different instances
586
- # Note: Missing doesn't define __eq__, so != comparison uses identity
587
-
588
- def test_type_variables(self):
589
- """Test type variables are properly defined."""
590
- # Test that type variables exist
591
- self.assertIsNotNone(K)
592
- self.assertIsNotNone(_T)
593
- self.assertIsNotNone(_Annotation)
594
-
595
- # Test that they are TypeVar instances
596
- from typing import TypeVar
597
- self.assertIsInstance(K, TypeVar)
598
- self.assertIsInstance(_T, TypeVar)
599
- self.assertIsInstance(_Annotation, TypeVar)
600
-
601
- def test_internal_array_type(self):
602
- """Test internal _Array type."""
603
- # Test that _Array exists and has proper module
604
- self.assertIsNotNone(_Array)
605
- self.assertEqual(_Array.__module__, "builtins")
606
-
607
- # Test that it can be parameterized
608
- parameterized = _Array[str]
609
- self.assertIsNotNone(parameterized)
610
-
611
-
612
- class TestRealWorldUsagePattterns(unittest.TestCase):
613
- """Test real-world usage patterns and integration."""
614
-
615
- def test_neural_network_typing(self):
616
- """Test typing patterns common in neural networks."""
617
-
618
- def linear_layer(
619
- x: Array["batch, in_features"],
620
- weight: Array["out_features, in_features"],
621
- bias: Array["out_features"]
622
- ) -> Array["batch, out_features"]:
623
- return x @ weight.T + bias
624
-
625
- # Test with actual arrays
626
- batch_size, in_features, out_features = 32, 128, 64
627
- key = jr.PRNGKey(42)
628
- key1, key2, key3 = jr.split(key, 3)
629
- x = jr.normal(key1, (batch_size, in_features))
630
- weight = jr.normal(key2, (out_features, in_features))
631
- bias = jr.normal(key3, (out_features,))
632
-
633
- result = linear_layer(x, weight, bias)
634
- self.assertEqual(result.shape, (batch_size, out_features))
635
-
636
- def test_pytree_parameter_filtering(self):
637
- """Test PyTree filtering patterns."""
638
-
639
- def extract_weights(params: PyTree[ArrayLike]) -> PyTree[ArrayLike]:
640
- # Mock filtering - in real code this would use jax.tree_util
641
- return jax.tree_util.tree_map(lambda x: x, params)
642
-
643
- # Test with parameter structure
644
- params = {
645
- "layer1": {"weight": jnp.ones((10, 5)), "bias": jnp.zeros(10)},
646
- "layer2": {"weight": jnp.ones((5, 3)), "bias": jnp.zeros(5)}
647
- }
648
-
649
- result = extract_weights(params)
650
- self.assertIsInstance(result, dict)
651
- self.assertIn("layer1", result)
652
- self.assertIn("layer2", result)
653
-
654
- def test_mixed_type_operations(self):
655
- """Test operations mixing different typed inputs."""
656
-
657
- def process_mixed_data(
658
- arrays: ArrayLike,
659
- shape: Shape,
660
- dtype: DTypeLike,
661
- seed: SeedOrKey
662
- ) -> jax.Array:
663
- # Convert inputs
664
- data = jnp.asarray(arrays, dtype=dtype)
665
- key = jr.PRNGKey(seed) if isinstance(seed, int) else seed
666
-
667
- # Generate noise and add to data
668
- noise = jr.normal(key, shape) * 0.1
669
- return data.reshape(shape) + noise
670
-
671
- # Test with mixed inputs
672
- result = process_mixed_data(
673
- arrays=[1, 2, 3, 4],
674
- shape=(2, 2),
675
- dtype='float32',
676
- seed=42
677
- )
678
-
679
- self.assertEqual(result.shape, (2, 2))
680
- self.assertEqual(result.dtype, jnp.float32)
681
-
682
- def test_scientific_computing_pattern(self):
683
- """Test scientific computing usage patterns."""
684
-
685
- def numerical_integration(
686
- func: callable,
687
- bounds: ArrayLike,
688
- n_points: Size,
689
- dtype: DTypeLike = jnp.float32 # Use float32 for JAX compatibility
690
- ) -> jax.Array:
691
- # Mock numerical integration
692
- x = jnp.linspace(bounds[0], bounds[1], n_points, dtype=dtype)
693
- y = jax.vmap(func)(x)
694
- dx = (bounds[1] - bounds[0]) / n_points
695
- return jnp.sum(y) * dx
696
-
697
- # Test with simple bounds (skip units for simplicity)
698
- bounds_array = jnp.array([0.0, 1.0])
699
-
700
- result = numerical_integration(
701
- lambda t: t ** 2,
702
- bounds_array,
703
- 1000,
704
- jnp.float32
705
- )
706
-
707
- self.assertIsInstance(result, jax.Array)
708
- self.assertEqual(result.dtype, jnp.float32)
709
-
710
-
711
- class TestTypeHintCompatibility(unittest.TestCase):
712
- """Test compatibility with Python's typing system."""
713
-
714
- def test_get_type_hints(self):
715
- """Test that type hints can be retrieved from annotated functions."""
716
-
717
- def annotated_function(
718
- arr: ArrayLike,
719
- shape: Shape,
720
- dtype: DTypeLike
721
- ) -> jax.Array:
722
- return jnp.zeros(shape, dtype=dtype)
723
-
724
- hints = get_type_hints(annotated_function)
725
-
726
- # Check that hints are captured
727
- self.assertIn('arr', hints)
728
- self.assertIn('shape', hints)
729
- self.assertIn('dtype', hints)
730
- self.assertIn('return', hints)
731
-
732
- def test_isinstance_checks(self):
733
- """Test isinstance checks with protocol types."""
734
- # Test Key protocol
735
- self.assertIsInstance("string", Key)
736
- self.assertIsInstance(42, Key)
737
- self.assertIsInstance(3.14, Key)
738
-
739
- # Test SupportsDType protocol (check for dtype attribute)
740
- arr = jnp.array([1, 2, 3])
741
- self.assertTrue(hasattr(arr, 'dtype'))
742
-
743
- np_arr = np.array([1, 2, 3])
744
- self.assertTrue(hasattr(np_arr, 'dtype'))
745
-
746
- # Test that objects without dtype don't have the attribute
747
- self.assertFalse(hasattr("string", 'dtype'))
748
-
749
- def test_module_imports(self):
750
- """Test that all types can be imported correctly."""
751
- from brainstate.typing import (
752
- Key, PathParts, Predicate, Filter, Array, ArrayLike,
753
- Shape, Size, Axes, PyTree, SeedOrKey, DType, DTypeLike,
754
- Missing
755
- )
756
-
757
- # Verify all imports succeeded
758
- types_to_check = [
759
- Key, PathParts, Predicate, Filter, Array, ArrayLike,
760
- Shape, Size, Axes, PyTree, SeedOrKey, DType, DTypeLike,
761
- Missing
762
- ]
763
-
764
- for type_obj in types_to_check:
765
- self.assertIsNotNone(type_obj)
766
-
767
- def test_documentation_strings(self):
768
- """Test that types have proper documentation."""
769
- documented_types = [
770
- Key, Array, PyTree, Missing, SupportsDType
771
- ]
772
-
773
- for type_obj in documented_types:
774
- self.assertIsNotNone(type_obj.__doc__)
775
- self.assertGreater(len(type_obj.__doc__), 10) # Has substantial documentation
776
-
777
-
778
if __name__ == '__main__':
    # verbosity=2 makes the runner print each test name and its outcome.
    unittest.main(verbosity=2)
1
+ # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """
17
+ Comprehensive tests for brainstate.typing module.
18
+
19
+ This test suite validates all type annotations, protocols, and type utilities
20
+ provided by the typing module, ensuring they work correctly with JAX, NumPy,
21
+ and BrainUnit integration.
22
+ """
23
+
24
+ import unittest
25
+ from typing import get_type_hints, Union, Any
26
+
27
+ import brainunit as u
28
+ import jax
29
+ import jax.numpy as jnp
30
+ import jax.random as jr
31
+ import numpy as np
32
+
33
+ from brainstate.typing import (
34
+ # Key and path types
35
+ Key, PathParts, FilterLiteral, Filter,
36
+
37
+ # Array types
38
+ Array, ArrayLike, Shape, Size, Axes, DType, DTypeLike, SupportsDType,
39
+
40
+ # PyTree types
41
+ PyTree,
42
+
43
+ # Random types
44
+ SeedOrKey,
45
+
46
+ # Utility types
47
+ Missing,
48
+
49
+ # Type variables
50
+ K, _T, _Annotation,
51
+
52
+ # Internal utilities for testing
53
+ _item_to_str, _maybe_tuple_to_str, _Array
54
+ )
55
+
56
+
57
class TestKeyProtocol(unittest.TestCase):
    """Test the Key protocol and related path types."""

    def setUp(self):
        """Create one string, one int and one float key fixture."""
        self.string_key = "layer1"
        self.int_key = 42
        self.float_key = 3.14

    def test_key_protocol_string(self):
        """Test that strings implement the Key protocol."""
        self.assertIsInstance(self.string_key, Key)
        self.assertTrue(hasattr(self.string_key, '__hash__'))
        self.assertTrue(hasattr(self.string_key, '__lt__'))

    def test_key_protocol_int(self):
        """Test that integers implement the Key protocol."""
        self.assertIsInstance(self.int_key, Key)
        self.assertTrue(hasattr(self.int_key, '__hash__'))
        self.assertTrue(hasattr(self.int_key, '__lt__'))

    def test_key_protocol_float(self):
        """Test that floats implement the Key protocol (uses ``self.float_key``)."""
        self.assertIsInstance(self.float_key, Key)
        self.assertTrue(hasattr(self.float_key, '__hash__'))
        self.assertTrue(hasattr(self.float_key, '__lt__'))

    def test_key_ordering(self):
        """Test that builtin key types are Keys and can be ordered."""
        for smaller, larger in (("a", "b"), (1, 2), (1.0, 2.0)):
            self.assertIsInstance(smaller, Key)
            # assertLess reports both operands on failure, unlike assertTrue(a < b).
            self.assertLess(smaller, larger)

    def test_custom_key_class(self):
        """Test custom class implementing Key protocol."""

        class CustomKey:
            def __init__(self, name: str):
                self.name = name

            def __hash__(self) -> int:
                return hash(self.name)

            def __eq__(self, other) -> bool:
                if not isinstance(other, CustomKey):
                    # Let Python try the reflected operation instead of
                    # claiming inequality for foreign types.
                    return NotImplemented
                return self.name == other.name

            def __lt__(self, other) -> bool:
                if not isinstance(other, CustomKey):
                    return NotImplemented
                return self.name < other.name

        key1 = CustomKey("first")
        key2 = CustomKey("second")

        self.assertIsInstance(key1, Key)
        self.assertLess(key1, key2)
        self.assertFalse(key2 < key1)
        self.assertEqual(key1, CustomKey("first"))
        self.assertEqual(hash(key1), hash(CustomKey("first")))

    def test_path_parts(self):
        """Test PathParts type usage."""
        # Simple path
        path1: PathParts = ("model", "layers", 0, "weights")
        self.assertEqual(len(path1), 4)
        self.assertIsInstance(path1[0], str)
        self.assertIsInstance(path1[2], int)

        # Empty path
        path2: PathParts = ()
        self.assertEqual(len(path2), 0)

        # Mixed types path
        path3: PathParts = ("root", 1, "sub", 2.5)
        self.assertEqual(len(path3), 4)

    def test_predicate_functions(self):
        """Test Predicate function type."""

        def is_weight_matrix(path: PathParts, value: Any) -> bool:
            return len(path) > 0 and "weight" in str(path[-1]) and hasattr(value, 'ndim') and value.ndim == 2

        def is_bias_vector(path: PathParts, value: Any) -> bool:
            return len(path) > 0 and "bias" in str(path[-1]) and hasattr(value, 'ndim') and value.ndim == 1

        # Test with mock data
        weight_path: PathParts = ("layer", "weight")
        bias_path: PathParts = ("layer", "bias")

        weight_matrix = np.random.randn(10, 5)
        bias_vector = np.random.randn(5)

        self.assertTrue(is_weight_matrix(weight_path, weight_matrix))
        self.assertFalse(is_weight_matrix(bias_path, bias_vector))
        self.assertTrue(is_bias_vector(bias_path, bias_vector))
        self.assertFalse(is_bias_vector(weight_path, weight_matrix))

    def test_filter_types(self):
        """Test various filter type combinations."""
        # FilterLiteral types
        type_filter: FilterLiteral = float
        string_filter: FilterLiteral = "weight"
        predicate_filter: FilterLiteral = lambda path, x: hasattr(x, 'ndim')
        bool_filter: FilterLiteral = True
        ellipsis_filter: FilterLiteral = ...
        none_filter: FilterLiteral = None

        # Combined filters
        tuple_filter: Filter = (float, "weight")
        list_filter: Filter = [int, float, "bias"]
        nested_filter: Filter = [
            ("weight", lambda p, x: hasattr(x, 'ndim') and x.ndim == 2),
            ("bias", lambda p, x: hasattr(x, 'ndim') and x.ndim == 1),
        ]

        # Verify literal filters are correctly assigned
        self.assertIsInstance(type_filter, type)
        self.assertIsInstance(string_filter, str)
        self.assertTrue(callable(predicate_filter))
        self.assertTrue(predicate_filter((), np.zeros(3)))
        self.assertIsInstance(bool_filter, bool)
        self.assertEqual(ellipsis_filter, ...)
        self.assertIsNone(none_filter)

        # Verify combined filters keep their container kind and contents
        self.assertEqual(tuple_filter, (float, "weight"))
        self.assertEqual(list_filter, [int, float, "bias"])
        self.assertEqual([tag for tag, _ in nested_filter], ["weight", "bias"])
        weight_pred = nested_filter[0][1]
        bias_pred = nested_filter[1][1]
        self.assertTrue(weight_pred((), np.zeros((2, 2))))
        self.assertFalse(weight_pred((), np.zeros(2)))
        self.assertTrue(bias_pred((), np.zeros(2)))
        self.assertFalse(bias_pred((), np.zeros((2, 2))))
169
+
170
+
171
class TestArrayAnnotations(unittest.TestCase):
    """Exercise the ``Array`` annotation helper and its string-rendering utilities."""

    def test_array_basic_annotation(self):
        """An ``Array``-annotated function accepts both JAX and NumPy inputs."""

        def doubled(values: Array) -> Array:
            return values * 2

        from_jax = doubled(jnp.array([1, 2, 3]))
        from_numpy = doubled(np.array([1, 2, 3]))

        self.assertIsInstance(from_jax, jax.Array)
        self.assertIsInstance(from_numpy, np.ndarray)

    def test_array_shape_annotation(self):
        """Shape-string annotations do not interfere with a plain matmul."""

        def matmul(lhs: Array["m, n"], rhs: Array["n, k"]) -> Array["m, k"]:
            return lhs @ rhs

        lhs_key, rhs_key = jax.random.split(jax.random.PRNGKey(42))
        product = matmul(
            jax.random.normal(lhs_key, (3, 4)),
            jax.random.normal(rhs_key, (4, 5)),
        )
        self.assertEqual(product.shape, (3, 5))

    def test_array_class_getitem(self):
        """Subscripting ``Array`` with a shape string yields a usable alias."""
        two_axis = Array["batch, features"]
        self.assertIsNotNone(two_axis)
        self.assertTrue(hasattr(two_axis, '__origin__'))

        # Longer specs and specs containing an ellipsis are accepted too.
        for spec in ("batch, seq_len, d_model", "batch, ..."):
            self.assertIsNotNone(Array[spec])

    def test_item_to_str_function(self):
        """``_item_to_str`` renders strings, types, ellipsis and simple slices."""
        cases = [
            ("batch", "'batch'"),
            (float, "float"),
            (..., "..."),
            (slice("start", "stop"), "'start': 'stop'"),
        ]
        for item, rendered in cases:
            self.assertEqual(_item_to_str(item), rendered)

        # A slice that carries a step is rejected.
        with self.assertRaises(NotImplementedError):
            _item_to_str(slice("start", "stop", "step"))

    def test_maybe_tuple_to_str_function(self):
        """``_maybe_tuple_to_str`` handles scalars, empty tuples and tuples."""
        cases = [
            ("single", "'single'"),
            ((), "()"),
            (("batch", "features"), "'batch', 'features'"),
        ]
        for item, rendered in cases:
            self.assertEqual(_maybe_tuple_to_str(item), rendered)

    def test_array_module_setting(self):
        """``Array`` reports ``builtins`` as its module for cleaner display."""
        self.assertEqual(Array.__module__, "builtins")
256
+
257
+
258
class TestShapeAndSizeTypes(unittest.TestCase):
    """Check the ``Size``, ``Shape`` and ``Axes`` aliases with concrete values."""

    def test_size_type_variants(self):
        """``Size`` covers plain ints, int tuples, NumPy ints and mixed sequences."""
        scalar_size: Size = 10
        tuple_size: Size = (3, 4, 5)
        numpy_size: Size = np.int32(8)
        mixed_size: Size = [np.int64(2), 3, np.int32(4)]

        self.assertIsInstance(scalar_size, int)
        self.assertIsInstance(tuple_size, tuple)
        self.assertTrue(all(isinstance(dim, int) for dim in tuple_size))
        self.assertIsInstance(numpy_size, np.integer)
        self.assertIsInstance(mixed_size, list)

    def test_shape_type(self):
        """``Shape`` values are int sequences of any rank."""
        matrix_shape: Shape = (10, 20)
        tensor_shape: Shape = (5, 10, 15)
        vector_shape: Shape = (100,)

        self.assertEqual(len(matrix_shape), 2)
        self.assertTrue(all(isinstance(dim, int) for dim in matrix_shape))
        self.assertEqual(len(tensor_shape), 3)
        # A single dimension is still written as a one-element sequence.
        self.assertEqual(len(vector_shape), 1)

    def test_axes_type(self):
        """``Axes`` accepts a single int, a tuple of ints or a list of ints."""
        single_axis: Axes = 0
        axis_tuple: Axes = (0, 2)
        axis_list: Axes = [1, 3, 4]

        self.assertIsInstance(single_axis, int)
        self.assertIsInstance(axis_tuple, tuple)
        self.assertTrue(all(isinstance(ax, int) for ax in axis_tuple))
        self.assertIsInstance(axis_list, list)

    def test_shape_operations(self):
        """Shape/Axes-annotated helpers drive ``jnp.zeros`` and ``jnp.sum``."""

        def zeros_of(shape: Shape) -> jax.Array:
            return jnp.zeros(shape)

        def reduce_sum(array: ArrayLike, axes: Axes) -> jax.Array:
            return jnp.sum(array, axis=axes)

        self.assertEqual(zeros_of((3, 4)).shape, (3, 4))

        cube = jnp.ones((2, 3, 4))
        self.assertEqual(reduce_sum(cube, 0).shape, (3, 4))
        self.assertEqual(reduce_sum(cube, (0, 2)).shape, (3,))
330
+
331
+
332
class TestArrayLikeAndDType(unittest.TestCase):
    """Test ArrayLike and dtype-related types."""

    def test_arraylike_variants(self):
        """Test that every ArrayLike variant converts to a ``jax.Array``."""

        def process_data(data: ArrayLike) -> jax.Array:
            return jnp.asarray(data)

        variants = {
            "jax_array": jnp.array([1, 2, 3]),
            "numpy_array": np.array([1, 2, 3]),
            "python_int": 42,
            "python_float": 3.14,
            "python_bool": True,
            "python_complex": 1 + 2j,
            "numpy_float32": np.float32(2.5),
            "numpy_bool": np.bool_(False),
        }
        for label, data in variants.items():
            # subTest reports every failing variant instead of stopping at the first.
            with self.subTest(variant=label):
                self.assertIsInstance(process_data(data), jax.Array)

        # A Python scalar converts to a 0-d array.
        self.assertEqual(process_data(42).shape, ())

    def test_arraylike_brainunit_mantissa(self):
        """Test that a BrainUnit quantity's mantissa is accepted as ArrayLike."""
        try:
            mantissa = (1.5 * u.second).mantissa
        except (AttributeError, TypeError) as err:
            # Report a skip instead of silently passing when the quantity API
            # is not usable here; only quantity construction is guarded, so
            # a conversion failure below still fails the test.
            self.skipTest(f"BrainUnit quantity not usable: {err}")
        self.assertIsInstance(jnp.asarray(mantissa), jax.Array)

    def test_dtype_variants(self):
        """Test DType and DTypeLike variants."""

        def cast_array(array: ArrayLike, dtype: DTypeLike) -> jax.Array:
            return jnp.asarray(array, dtype=dtype)

        test_data = [1, 2, 3]

        # String dtype
        result1 = cast_array(test_data, 'float32')
        self.assertEqual(result1.dtype, jnp.float32)

        # NumPy type
        result2 = cast_array(test_data, np.float32)
        self.assertEqual(result2.dtype, jnp.float32)

        # Python type
        result3 = cast_array(test_data, float)
        self.assertTrue(jnp.issubdtype(result3.dtype, jnp.floating))

        # NumPy dtype object
        result4 = cast_array(test_data, np.dtype('int32'))
        self.assertEqual(result4.dtype, jnp.int32)

    def test_supports_dtype_protocol(self):
        """Test SupportsDType protocol."""

        def get_dtype(obj: SupportsDType) -> DType:
            return obj.dtype

        # Test with arrays
        arr = jnp.array([1.0, 2.0])
        dtype = get_dtype(arr)
        self.assertIsInstance(dtype, np.dtype)

        # Test with NumPy arrays
        np_arr = np.array([1, 2], dtype=np.int64)
        dtype2 = get_dtype(np_arr)
        self.assertEqual(dtype2, np.int64)

    def test_dtype_alias(self):
        """Test DType alias."""

        def create_array(shape: Shape, dtype: DType) -> jax.Array:
            return jnp.zeros(shape, dtype=dtype)

        arr = create_array((3, 4), np.float32)
        self.assertEqual(arr.shape, (3, 4))
        self.assertEqual(arr.dtype, jnp.float32)
431
+
432
+
433
class TestPyTreeTypes(unittest.TestCase):
    """Test PyTree type annotations."""

    def test_pytree_basic_usage(self):
        """Test basic PyTree type usage."""

        def tree_function(tree: PyTree[float]) -> PyTree[float]:
            return jax.tree_util.tree_map(lambda x: x * 2, tree)

        # Dict-structured tree
        tree1 = {"a": 1.0, "b": 2.0}
        result1 = tree_function(tree1)
        self.assertAlmostEqual(result1["a"], 2.0)
        self.assertAlmostEqual(result1["b"], 4.0)

        # List-structured tree
        tree2 = [1.0, 2.0, 3.0]
        result2 = tree_function(tree2)
        expected = [2.0, 4.0, 6.0]
        # zip() stops at the shorter input, so pin the length before comparing.
        self.assertEqual(len(result2), len(expected))
        for actual, expect in zip(result2, expected):
            self.assertAlmostEqual(actual, expect)

    def test_pytree_with_structure(self):
        """Test PyTree with structure annotations."""

        def structured_function(tree: PyTree[float, "T"]) -> PyTree[float, "T"]:
            return jax.tree_util.tree_map(lambda x: x + 1, tree)

        tree = {"weights": 1.0, "bias": 2.0}
        result = structured_function(tree)
        self.assertAlmostEqual(result["weights"], 2.0)
        self.assertAlmostEqual(result["bias"], 3.0)

    def test_pytree_instantiation_error(self):
        """Test that PyTree cannot be instantiated."""
        with self.assertRaises(RuntimeError):
            PyTree()

    def test_pytree_subscripting(self):
        """Test PyTree subscripting behavior."""
        # Single type parameter
        pytree_type = PyTree[float]
        self.assertTrue(hasattr(pytree_type, 'leaftype'))
        self.assertEqual(pytree_type.leaftype, float)

        # Type and structure parameters
        pytree_structured = PyTree[int, "T"]
        self.assertTrue(hasattr(pytree_structured, 'leaftype'))
        self.assertTrue(hasattr(pytree_structured, 'structure'))
        self.assertEqual(pytree_structured.leaftype, int)
        self.assertEqual(pytree_structured.structure, "T")

    def test_pytree_structure_validation(self):
        """Test PyTree structure validation."""
        # Valid structure names: subscripting must not raise.
        for structure in ("T", "S T", "... T", "T ...", "foo bar"):
            with self.subTest(structure=structure):
                PyTree[float, structure]

        # Invalid structure names must raise ValueError.
        invalid_structures = {
            "": "empty string",
            "invalid-identifier": "not an identifier",
            "123abc": "starts with a digit",
        }
        for structure, reason in invalid_structures.items():
            with self.subTest(structure=structure, reason=reason):
                with self.assertRaises(ValueError):
                    PyTree[float, structure]

    def test_pytree_tuple_length_validation(self):
        """Test PyTree tuple parameter validation."""
        # Valid 2-tuple
        PyTree[float, "T"]

        # Invalid tuple lengths
        with self.assertRaises(ValueError):
            PyTree[float, "T", "extra"]  # 3-tuple

        with self.assertRaises(ValueError):
            PyTree[float,]  # trailing comma passes the 1-tuple (float,)
513
+
514
+
515
class TestRandomTypes(unittest.TestCase):
    """Check ``SeedOrKey`` with integer seeds, JAX keys and raw NumPy keys."""

    def test_seed_or_key_variants(self):
        """Each ``SeedOrKey`` form drives ``jax.random.normal`` to the requested shape."""

        def generate_random(key: SeedOrKey, shape: Shape) -> jax.Array:
            prng = jr.PRNGKey(key) if isinstance(key, int) else key
            return jr.normal(prng, shape)

        cases = (
            (42, (3, 4)),                                   # integer seed
            (jr.PRNGKey(123), (5,)),                        # JAX PRNG key
            (np.array([1, 2], dtype=np.uint32), (2, 2)),    # raw NumPy key
        )
        for seed_or_key, shape in cases:
            self.assertEqual(generate_random(seed_or_key, shape).shape, shape)

    def test_reproducibility_with_seeds(self):
        """Identical seeds or keys reproduce identical samples."""

        def generate_data(seed: SeedOrKey) -> jax.Array:
            key = jr.PRNGKey(seed) if isinstance(seed, int) else seed
            return jr.normal(key, (5,))

        for source in (42, jr.PRNGKey(999)):
            first = generate_data(source)
            second = generate_data(source)
            np.testing.assert_array_equal(first, second)
560
+
561
+
562
class TestUtilityTypes(unittest.TestCase):
    """Cover the ``Missing`` sentinel, the module type variables and ``_Array``."""

    def test_missing_sentinel(self):
        """A ``Missing`` instance distinguishes "not passed" from an explicit ``None``."""
        sentinel = Missing()

        def function_with_optional_param(value: Union[int, None, Missing] = sentinel):
            if value is sentinel:
                return "no_value"
            if value is None:
                return "explicit_none"
            return f"value_{value}"

        expectations = (
            ((), "no_value"),
            ((None,), "explicit_none"),
            ((42,), "value_42"),
        )
        for call_args, outcome in expectations:
            self.assertEqual(function_with_optional_param(*call_args), outcome)

        # Every Missing() call produces a distinct object.
        first, second = Missing(), Missing()
        self.assertIsNot(first, second)

    def test_type_variables(self):
        """``K``, ``_T`` and ``_Annotation`` exist and are ``TypeVar`` instances."""
        from typing import TypeVar

        for type_var in (K, _T, _Annotation):
            self.assertIsNotNone(type_var)
            self.assertIsInstance(type_var, TypeVar)

    def test_internal_array_type(self):
        """``_Array`` exists, reports ``builtins`` as its module and is subscriptable."""
        self.assertIsNotNone(_Array)
        self.assertEqual(_Array.__module__, "builtins")
        self.assertIsNotNone(_Array[str])
610
+
611
+
612
# NOTE(review): "Pattterns" is a typo; the name is kept so existing test IDs
# and selectors keep working.
class TestRealWorldUsagePattterns(unittest.TestCase):
    """Test real-world usage patterns and integration."""

    def test_neural_network_typing(self):
        """Test typing patterns common in neural networks."""

        def linear_layer(
            x: Array["batch, in_features"],
            weight: Array["out_features, in_features"],
            bias: Array["out_features"]
        ) -> Array["batch, out_features"]:
            return x @ weight.T + bias

        batch_size, in_features, out_features = 32, 128, 64
        key = jr.PRNGKey(42)
        key1, key2, key3 = jr.split(key, 3)
        x = jr.normal(key1, (batch_size, in_features))
        weight = jr.normal(key2, (out_features, in_features))
        bias = jr.normal(key3, (out_features,))

        result = linear_layer(x, weight, bias)
        self.assertEqual(result.shape, (batch_size, out_features))

    def test_pytree_parameter_filtering(self):
        """Test PyTree filtering patterns."""

        def extract_weights(params: PyTree[ArrayLike]) -> PyTree[ArrayLike]:
            # Mock filtering: an identity tree_map standing in for real filtering.
            return jax.tree_util.tree_map(lambda x: x, params)

        params = {
            "layer1": {"weight": jnp.ones((10, 5)), "bias": jnp.zeros(10)},
            "layer2": {"weight": jnp.ones((5, 3)), "bias": jnp.zeros(5)}
        }

        result = extract_weights(params)
        self.assertIsInstance(result, dict)
        self.assertIn("layer1", result)
        self.assertIn("layer2", result)
        # An identity tree_map must preserve the tree structure exactly.
        self.assertEqual(
            jax.tree_util.tree_structure(result),
            jax.tree_util.tree_structure(params),
        )

    def test_mixed_type_operations(self):
        """Test operations mixing different typed inputs."""

        def process_mixed_data(
            arrays: ArrayLike,
            shape: Shape,
            dtype: DTypeLike,
            seed: SeedOrKey
        ) -> jax.Array:
            # Convert inputs
            data = jnp.asarray(arrays, dtype=dtype)
            key = jr.PRNGKey(seed) if isinstance(seed, int) else seed

            # Generate noise and add to data
            noise = jr.normal(key, shape) * 0.1
            return data.reshape(shape) + noise

        result = process_mixed_data(
            arrays=[1, 2, 3, 4],
            shape=(2, 2),
            dtype='float32',
            seed=42
        )

        self.assertEqual(result.shape, (2, 2))
        self.assertEqual(result.dtype, jnp.float32)

    def test_scientific_computing_pattern(self):
        """Test scientific computing usage patterns."""
        # ``callable`` is a builtin function, not a type; use typing.Callable.
        from typing import Callable

        def numerical_integration(
            func: Callable[[jax.Array], jax.Array],
            bounds: ArrayLike,
            n_points: Size,
            dtype: DTypeLike = jnp.float32  # float32 matches JAX's default precision
        ) -> jax.Array:
            # Left Riemann sum. endpoint=False makes the sample spacing exactly
            # dx, so the spacing and the weight agree.
            dx = (bounds[1] - bounds[0]) / n_points
            x = jnp.linspace(bounds[0], bounds[1], n_points, endpoint=False, dtype=dtype)
            y = jax.vmap(func)(x)
            return jnp.sum(y) * dx

        bounds_array = jnp.array([0.0, 1.0])

        result = numerical_integration(
            lambda t: t ** 2,
            bounds_array,
            1000,
            jnp.float32
        )

        self.assertIsInstance(result, jax.Array)
        self.assertEqual(result.dtype, jnp.float32)
        # The integral of t**2 over [0, 1] is 1/3; the left sum with 1000
        # points is within about 5e-4 of it.
        self.assertAlmostEqual(float(result), 1.0 / 3.0, places=2)
709
+
710
+
711
class TestTypeHintCompatibility(unittest.TestCase):
    """Make sure the brainstate aliases cooperate with Python's ``typing`` tooling."""

    def test_get_type_hints(self):
        """``typing.get_type_hints`` resolves every annotated parameter and the return."""

        def annotated_function(
            arr: ArrayLike,
            shape: Shape,
            dtype: DTypeLike
        ) -> jax.Array:
            return jnp.zeros(shape, dtype=dtype)

        resolved = get_type_hints(annotated_function)
        for hint_name in ('arr', 'shape', 'dtype', 'return'):
            self.assertIn(hint_name, resolved)

    def test_isinstance_checks(self):
        """Check ``Key`` against builtin scalars and ``dtype`` presence on arrays."""
        for candidate in ("string", 42, 3.14):
            self.assertIsInstance(candidate, Key)

        jax_values = jnp.array([1, 2, 3])
        numpy_values = np.array([1, 2, 3])
        for with_dtype in (jax_values, numpy_values):
            self.assertTrue(hasattr(with_dtype, 'dtype'))

        # Negative control: builtins such as str carry no dtype.
        self.assertFalse(hasattr("string", 'dtype'))

    def test_module_imports(self):
        """Re-import the public typing names and check none resolves to ``None``."""
        from brainstate.typing import (
            Array, ArrayLike, Axes, DType, DTypeLike, Filter, Key, Missing,
            PathParts, Predicate, PyTree, SeedOrKey, Shape, Size,
        )

        imported = (
            Key, PathParts, Predicate, Filter, Array, ArrayLike,
            Shape, Size, Axes, PyTree, SeedOrKey, DType, DTypeLike,
            Missing,
        )
        for public_obj in imported:
            self.assertIsNotNone(public_obj)

    def test_documentation_strings(self):
        """Require a docstring longer than 10 characters on key public types."""
        for documented in (Key, Array, PyTree, Missing, SupportsDType):
            doc = documented.__doc__
            self.assertIsNotNone(doc)
            # Guard against empty or placeholder docstrings.
            self.assertGreater(len(doc), 10)
776
+
777
+
778
if __name__ == '__main__':
    # verbosity=2 makes the runner print each test name and its outcome.
    unittest.main(verbosity=2)