brainstate 0.1.9__py2.py3-none-any.whl → 0.2.0__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +130 -19
- brainstate/_compatible_import.py +201 -9
- brainstate/_compatible_import_test.py +681 -0
- brainstate/_deprecation.py +210 -0
- brainstate/_deprecation_test.py +2319 -0
- brainstate/{util/error.py → _error.py} +10 -20
- brainstate/_state.py +94 -47
- brainstate/_state_test.py +1 -1
- brainstate/_utils.py +1 -1
- brainstate/environ.py +1279 -347
- brainstate/environ_test.py +1187 -26
- brainstate/graph/__init__.py +6 -13
- brainstate/graph/_node.py +240 -0
- brainstate/graph/_node_test.py +589 -0
- brainstate/graph/{_graph_operation.py → _operation.py} +632 -746
- brainstate/graph/_operation_test.py +1147 -0
- brainstate/mixin.py +1209 -141
- brainstate/mixin_test.py +991 -51
- brainstate/nn/__init__.py +74 -72
- brainstate/nn/_activations.py +587 -295
- brainstate/nn/_activations_test.py +109 -86
- brainstate/nn/_collective_ops.py +393 -274
- brainstate/nn/_collective_ops_test.py +746 -15
- brainstate/nn/_common.py +114 -66
- brainstate/nn/_common_test.py +154 -0
- brainstate/nn/_conv.py +1652 -143
- brainstate/nn/_conv_test.py +838 -227
- brainstate/nn/_delay.py +95 -29
- brainstate/nn/_delay_test.py +25 -20
- brainstate/nn/_dropout.py +359 -167
- brainstate/nn/_dropout_test.py +429 -52
- brainstate/nn/_dynamics.py +14 -90
- brainstate/nn/_dynamics_test.py +1 -12
- brainstate/nn/_elementwise.py +492 -313
- brainstate/nn/_elementwise_test.py +806 -145
- brainstate/nn/_embedding.py +369 -19
- brainstate/nn/_embedding_test.py +156 -0
- brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +10 -16
- brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +6 -5
- brainstate/nn/{_linear_mv.py → _event_linear.py} +2 -2
- brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +6 -5
- brainstate/nn/_exp_euler.py +200 -38
- brainstate/nn/_exp_euler_test.py +350 -8
- brainstate/nn/_linear.py +391 -71
- brainstate/nn/_linear_test.py +427 -59
- brainstate/nn/_metrics.py +1070 -0
- brainstate/nn/_metrics_test.py +611 -0
- brainstate/nn/_module.py +10 -3
- brainstate/nn/_module_test.py +1 -1
- brainstate/nn/_normalizations.py +688 -329
- brainstate/nn/_normalizations_test.py +663 -37
- brainstate/nn/_paddings.py +1020 -0
- brainstate/nn/_paddings_test.py +723 -0
- brainstate/nn/_poolings.py +1404 -342
- brainstate/nn/_poolings_test.py +828 -92
- brainstate/nn/{_rate_rnns.py → _rnns.py} +446 -54
- brainstate/nn/_rnns_test.py +593 -0
- brainstate/nn/_utils.py +132 -5
- brainstate/nn/_utils_test.py +402 -0
- brainstate/{init/_random_inits.py → nn/init.py} +301 -45
- brainstate/{init/_random_inits_test.py → nn/init_test.py} +51 -20
- brainstate/random/__init__.py +247 -1
- brainstate/random/_rand_funs.py +668 -346
- brainstate/random/_rand_funs_test.py +74 -1
- brainstate/random/_rand_seed.py +541 -76
- brainstate/random/_rand_seed_test.py +1 -1
- brainstate/random/_rand_state.py +601 -393
- brainstate/random/_rand_state_test.py +551 -0
- brainstate/transform/__init__.py +59 -0
- brainstate/transform/_ad_checkpoint.py +176 -0
- brainstate/{compile → transform}/_ad_checkpoint_test.py +1 -1
- brainstate/{augment → transform}/_autograd.py +360 -113
- brainstate/{augment → transform}/_autograd_test.py +2 -2
- brainstate/transform/_conditions.py +316 -0
- brainstate/{compile → transform}/_conditions_test.py +11 -11
- brainstate/{compile → transform}/_error_if.py +22 -20
- brainstate/{compile → transform}/_error_if_test.py +1 -1
- brainstate/transform/_eval_shape.py +145 -0
- brainstate/{augment → transform}/_eval_shape_test.py +1 -1
- brainstate/{compile → transform}/_jit.py +99 -46
- brainstate/{compile → transform}/_jit_test.py +3 -3
- brainstate/{compile → transform}/_loop_collect_return.py +219 -80
- brainstate/{compile → transform}/_loop_collect_return_test.py +1 -1
- brainstate/{compile → transform}/_loop_no_collection.py +133 -34
- brainstate/{compile → transform}/_loop_no_collection_test.py +2 -2
- brainstate/transform/_make_jaxpr.py +2016 -0
- brainstate/transform/_make_jaxpr_test.py +1510 -0
- brainstate/transform/_mapping.py +529 -0
- brainstate/transform/_mapping_test.py +194 -0
- brainstate/{compile → transform}/_progress_bar.py +78 -25
- brainstate/{augment → transform}/_random.py +65 -45
- brainstate/{compile → transform}/_unvmap.py +102 -5
- brainstate/transform/_util.py +286 -0
- brainstate/typing.py +594 -61
- brainstate/typing_test.py +780 -0
- brainstate/util/__init__.py +9 -32
- brainstate/util/_others.py +1025 -0
- brainstate/util/_others_test.py +962 -0
- brainstate/util/_pretty_pytree.py +1301 -0
- brainstate/util/_pretty_pytree_test.py +675 -0
- brainstate/util/{pretty_repr.py → _pretty_repr.py} +161 -27
- brainstate/util/_pretty_repr_test.py +696 -0
- brainstate/util/filter.py +557 -81
- brainstate/util/filter_test.py +912 -0
- brainstate/util/struct.py +769 -382
- brainstate/util/struct_test.py +602 -0
- {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/METADATA +34 -17
- brainstate-0.2.0.dist-info/RECORD +111 -0
- brainstate/augment/__init__.py +0 -30
- brainstate/augment/_eval_shape.py +0 -99
- brainstate/augment/_mapping.py +0 -1060
- brainstate/augment/_mapping_test.py +0 -597
- brainstate/compile/__init__.py +0 -38
- brainstate/compile/_ad_checkpoint.py +0 -204
- brainstate/compile/_conditions.py +0 -256
- brainstate/compile/_make_jaxpr.py +0 -888
- brainstate/compile/_make_jaxpr_test.py +0 -156
- brainstate/compile/_util.py +0 -147
- brainstate/functional/__init__.py +0 -27
- brainstate/graph/_graph_node.py +0 -244
- brainstate/graph/_graph_node_test.py +0 -73
- brainstate/graph/_graph_operation_test.py +0 -563
- brainstate/init/__init__.py +0 -26
- brainstate/init/_base.py +0 -52
- brainstate/init/_generic.py +0 -244
- brainstate/init/_regular_inits.py +0 -105
- brainstate/init/_regular_inits_test.py +0 -50
- brainstate/nn/_inputs.py +0 -608
- brainstate/nn/_ltp.py +0 -28
- brainstate/nn/_neuron.py +0 -705
- brainstate/nn/_neuron_test.py +0 -161
- brainstate/nn/_others.py +0 -46
- brainstate/nn/_projection.py +0 -486
- brainstate/nn/_rate_rnns_test.py +0 -63
- brainstate/nn/_readout.py +0 -209
- brainstate/nn/_readout_test.py +0 -53
- brainstate/nn/_stp.py +0 -236
- brainstate/nn/_synapse.py +0 -505
- brainstate/nn/_synapse_test.py +0 -131
- brainstate/nn/_synaptic_projection.py +0 -423
- brainstate/nn/_synouts.py +0 -162
- brainstate/nn/_synouts_test.py +0 -57
- brainstate/nn/metrics.py +0 -388
- brainstate/optim/__init__.py +0 -38
- brainstate/optim/_base.py +0 -64
- brainstate/optim/_lr_scheduler.py +0 -448
- brainstate/optim/_lr_scheduler_test.py +0 -50
- brainstate/optim/_optax_optimizer.py +0 -152
- brainstate/optim/_optax_optimizer_test.py +0 -53
- brainstate/optim/_sgd_optimizer.py +0 -1104
- brainstate/random/_random_for_unit.py +0 -52
- brainstate/surrogate.py +0 -1957
- brainstate/transform.py +0 -23
- brainstate/util/caller.py +0 -98
- brainstate/util/others.py +0 -540
- brainstate/util/pretty_pytree.py +0 -945
- brainstate/util/pretty_pytree_test.py +0 -159
- brainstate/util/pretty_table.py +0 -2954
- brainstate/util/scaling.py +0 -258
- brainstate-0.1.9.dist-info/RECORD +0 -130
- {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/WHEEL +0 -0
- {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/licenses/LICENSE +0 -0
- {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,780 @@
|
|
1
|
+
# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
"""
|
17
|
+
Comprehensive tests for brainstate.typing module.
|
18
|
+
|
19
|
+
This test suite validates all type annotations, protocols, and type utilities
|
20
|
+
provided by the typing module, ensuring they work correctly with JAX, NumPy,
|
21
|
+
and BrainUnit integration.
|
22
|
+
"""
|
23
|
+
|
24
|
+
import unittest
|
25
|
+
from typing import get_type_hints, Union, Any
|
26
|
+
|
27
|
+
import brainunit as u
|
28
|
+
import jax
|
29
|
+
import jax.numpy as jnp
|
30
|
+
import jax.random as jr
|
31
|
+
import numpy as np
|
32
|
+
|
33
|
+
from brainstate.typing import (
|
34
|
+
# Key and path types
|
35
|
+
Key, PathParts, FilterLiteral, Filter,
|
36
|
+
|
37
|
+
# Array types
|
38
|
+
Array, ArrayLike, Shape, Size, Axes, DType, DTypeLike, SupportsDType,
|
39
|
+
|
40
|
+
# PyTree types
|
41
|
+
PyTree,
|
42
|
+
|
43
|
+
# Random types
|
44
|
+
SeedOrKey,
|
45
|
+
|
46
|
+
# Utility types
|
47
|
+
Missing,
|
48
|
+
|
49
|
+
# Type variables
|
50
|
+
K, _T, _Annotation,
|
51
|
+
|
52
|
+
# Internal utilities for testing
|
53
|
+
_item_to_str, _maybe_tuple_to_str, _Array
|
54
|
+
)
|
55
|
+
|
56
|
+
|
57
|
+
class TestKeyProtocol(unittest.TestCase):
    """Test the Key protocol and related path types."""

    def setUp(self):
        """Create one key of each built-in hashable, orderable type."""
        self.string_key = "layer1"
        self.int_key = 42
        self.float_key = 3.14

    def test_key_protocol_string(self):
        """Test that strings implement the Key protocol."""
        self.assertIsInstance(self.string_key, Key)
        self.assertTrue(hasattr(self.string_key, '__hash__'))
        self.assertTrue(hasattr(self.string_key, '__lt__'))

    def test_key_protocol_int(self):
        """Test that integers implement the Key protocol."""
        self.assertIsInstance(self.int_key, Key)
        self.assertTrue(hasattr(self.int_key, '__hash__'))
        self.assertTrue(hasattr(self.int_key, '__lt__'))

    def test_key_protocol_float(self):
        """Test that floats implement the Key protocol (uses the ``float_key`` fixture)."""
        self.assertIsInstance(self.float_key, Key)
        self.assertTrue(hasattr(self.float_key, '__hash__'))
        self.assertTrue(hasattr(self.float_key, '__lt__'))

    def test_key_ordering(self):
        """Test that keys of the same type can be ordered."""
        self.assertLess("a", "b")
        self.assertLess(1, 2)
        self.assertLess(1.0, 2.0)
        # Also order the fixture keys themselves, so the values that were
        # checked against the protocol are the ones being compared.
        self.assertLess(self.string_key, "layer2")
        self.assertLess(self.int_key, self.int_key + 1)
        self.assertLess(self.float_key, self.float_key + 1.0)

    def test_custom_key_class(self):
        """Test a custom class implementing the Key protocol."""

        class CustomKey:
            def __init__(self, name: str):
                self.name = name

            def __hash__(self) -> int:
                return hash(self.name)

            def __eq__(self, other) -> bool:
                return isinstance(other, CustomKey) and self.name == other.name

            def __lt__(self, other) -> bool:
                return isinstance(other, CustomKey) and self.name < other.name

        key1 = CustomKey("first")
        key2 = CustomKey("second")

        self.assertIsInstance(key1, Key)
        self.assertTrue(key1 < key2)
        self.assertFalse(key2 < key1)
        self.assertEqual(hash(key1), hash(CustomKey("first")))

    def test_path_parts(self):
        """Test PathParts type usage."""
        # Simple path
        path1: PathParts = ("model", "layers", 0, "weights")
        self.assertEqual(len(path1), 4)
        self.assertIsInstance(path1[0], str)
        self.assertIsInstance(path1[2], int)

        # Empty path
        path2: PathParts = ()
        self.assertEqual(len(path2), 0)

        # Mixed types path
        path3: PathParts = ("root", 1, "sub", 2.5)
        self.assertEqual(len(path3), 4)
        self.assertIsInstance(path3[3], float)

    def test_predicate_functions(self):
        """Test Predicate function type."""

        def is_weight_matrix(path: PathParts, value: Any) -> bool:
            return len(path) > 0 and "weight" in str(path[-1]) and hasattr(value, 'ndim') and value.ndim == 2

        def is_bias_vector(path: PathParts, value: Any) -> bool:
            return len(path) > 0 and "bias" in str(path[-1]) and hasattr(value, 'ndim') and value.ndim == 1

        # Test with mock data
        weight_path: PathParts = ("layer", "weight")
        bias_path: PathParts = ("layer", "bias")

        weight_matrix = np.random.randn(10, 5)
        bias_vector = np.random.randn(5)

        self.assertTrue(is_weight_matrix(weight_path, weight_matrix))
        self.assertFalse(is_weight_matrix(bias_path, bias_vector))
        self.assertTrue(is_bias_vector(bias_path, bias_vector))
        self.assertFalse(is_bias_vector(weight_path, weight_matrix))

    def test_filter_types(self):
        """Test various filter type combinations, including the combined filters."""
        # FilterLiteral types
        type_filter: FilterLiteral = float
        string_filter: FilterLiteral = "weight"
        predicate_filter: FilterLiteral = lambda path, x: hasattr(x, 'ndim')
        bool_filter: FilterLiteral = True
        ellipsis_filter: FilterLiteral = ...
        none_filter: FilterLiteral = None

        # Combined filters
        tuple_filter: Filter = (float, "weight")
        list_filter: Filter = [int, float, "bias"]
        nested_filter: Filter = [
            ("weight", lambda p, x: hasattr(x, 'ndim') and x.ndim == 2),
            ("bias", lambda p, x: hasattr(x, 'ndim') and x.ndim == 1),
        ]

        # Verify the literal filters are correctly assigned
        self.assertIsInstance(type_filter, type)
        self.assertIsInstance(string_filter, str)
        self.assertTrue(callable(predicate_filter))
        self.assertIsInstance(bool_filter, bool)
        self.assertEqual(ellipsis_filter, ...)
        self.assertIsNone(none_filter)

        # Verify the combined filters, which were previously built but never
        # checked; the predicates are actually evaluated.
        matrix = np.ones((2, 2))
        vector = np.ones(2)

        self.assertTrue(predicate_filter((), matrix))
        self.assertFalse(predicate_filter((), 1.0))  # plain float has no ``ndim``

        self.assertIsInstance(tuple_filter, tuple)
        self.assertEqual(tuple_filter, (float, "weight"))
        self.assertIsInstance(list_filter, list)
        self.assertEqual(list_filter, [int, float, "bias"])

        self.assertIsInstance(nested_filter, list)
        self.assertEqual(len(nested_filter), 2)
        (weight_name, weight_pred), (bias_name, bias_pred) = nested_filter
        self.assertEqual(weight_name, "weight")
        self.assertTrue(weight_pred((), matrix))
        self.assertFalse(weight_pred((), vector))
        self.assertEqual(bias_name, "bias")
        self.assertTrue(bias_pred((), vector))
        self.assertFalse(bias_pred((), matrix))
|
169
|
+
|
170
|
+
|
171
|
+
class TestArrayAnnotations(unittest.TestCase):
    """Exercise the ``Array`` annotation helper and its string-rendering utilities."""

    def test_array_basic_annotation(self):
        """An ``Array``-annotated function accepts both JAX and NumPy input."""

        def double(values: Array) -> Array:
            return values * 2

        doubled_jax = double(jnp.array([1, 2, 3]))
        doubled_np = double(np.array([1, 2, 3]))

        self.assertIsInstance(doubled_jax, jax.Array)
        self.assertIsInstance(doubled_np, np.ndarray)

    def test_array_shape_annotation(self):
        """Shape-string annotations do not interfere with normal execution."""

        def matmul(lhs: Array["m, n"], rhs: Array["n, k"]) -> Array["m, k"]:
            return lhs @ rhs

        lhs_key, rhs_key = jax.random.split(jax.random.PRNGKey(42))
        product = matmul(
            jax.random.normal(lhs_key, (3, 4)),
            jax.random.normal(rhs_key, (4, 5)),
        )

        self.assertEqual(product.shape, (3, 5))

    def test_array_class_getitem(self):
        """Subscripting ``Array`` with a shape string yields an annotation object."""
        batch_features = Array["batch, features"]
        self.assertIsNotNone(batch_features)
        self.assertTrue(hasattr(batch_features, '__origin__'))

        # Longer specs and specs containing an ellipsis are also accepted.
        for spec in ("batch, seq_len, d_model", "batch, ..."):
            self.assertIsNotNone(Array[spec])

    def test_item_to_str_function(self):
        """``_item_to_str`` renders strings, types, ellipsis and two-part slices."""
        cases = [
            ("batch", "'batch'"),
            (float, "float"),
            (..., "..."),
            (slice("start", "stop"), "'start': 'stop'"),
        ]
        for item, rendered in cases:
            self.assertEqual(_item_to_str(item), rendered)

        # A slice that carries a step is rejected.
        stepped = slice("start", "stop", "step")
        self.assertRaises(NotImplementedError, _item_to_str, stepped)

    def test_maybe_tuple_to_str_function(self):
        """``_maybe_tuple_to_str`` handles single items, empty and non-empty tuples."""
        expectations = {
            "single": "'single'",
            (): "()",
            ("batch", "features"): "'batch', 'features'",
        }
        for value, rendered in expectations.items():
            self.assertEqual(_maybe_tuple_to_str(value), rendered)

    def test_array_module_setting(self):
        """``Array`` reports ``builtins`` as its module so it displays cleanly."""
        self.assertEqual("builtins", Array.__module__)
|
256
|
+
|
257
|
+
|
258
|
+
class TestShapeAndSizeTypes(unittest.TestCase):
    """Check the ``Size``, ``Shape`` and ``Axes`` aliases against concrete values."""

    def test_size_type_variants(self):
        """``Size`` covers scalars, tuples, NumPy integers and mixed lists."""
        scalar_size: Size = 10
        tuple_size: Size = (3, 4, 5)
        numpy_size: Size = np.int32(8)
        mixed_size: Size = [np.int64(2), 3, np.int32(4)]

        self.assertIsInstance(scalar_size, int)
        self.assertIsInstance(tuple_size, tuple)
        self.assertTrue(all(isinstance(dim, int) for dim in tuple_size))
        self.assertIsInstance(numpy_size, np.integer)
        self.assertIsInstance(mixed_size, list)

    def test_shape_type(self):
        """``Shape`` values of rank 1, 2 and 3 are plain integer sequences."""
        matrix_shape: Shape = (10, 20)
        tensor_shape: Shape = (5, 10, 15)
        vector_shape: Shape = (100,)  # a 1-D shape is still a sequence

        self.assertEqual(len(matrix_shape), 2)
        self.assertTrue(all(isinstance(dim, int) for dim in matrix_shape))
        self.assertEqual(len(tensor_shape), 3)
        self.assertEqual(len(vector_shape), 1)

    def test_axes_type(self):
        """``Axes`` accepts a single int, a tuple of ints or a list of ints."""
        single_axis: Axes = 0
        axis_tuple: Axes = (0, 2)
        axis_list: Axes = [1, 3, 4]

        self.assertIsInstance(single_axis, int)
        self.assertIsInstance(axis_tuple, tuple)
        self.assertTrue(all(isinstance(ax, int) for ax in axis_tuple))
        self.assertIsInstance(axis_list, list)

    def test_shape_operations(self):
        """``Shape`` and ``Axes`` values feed directly into ``jax.numpy`` calls."""

        def create_zeros(shape: Shape) -> jax.Array:
            return jnp.zeros(shape)

        def sum_along_axes(array: ArrayLike, axes: Axes) -> jax.Array:
            return jnp.sum(array, axis=axes)

        self.assertEqual(create_zeros((3, 4)).shape, (3, 4))

        cube = jnp.ones((2, 3, 4))
        for axes, expected_shape in ((0, (3, 4)), ((0, 2), (3,))):
            self.assertEqual(sum_along_axes(cube, axes).shape, expected_shape)
|
330
|
+
|
331
|
+
|
332
|
+
class TestArrayLikeAndDType(unittest.TestCase):
    """Test ArrayLike and dtype-related types."""

    def test_arraylike_variants(self):
        """Test that every ArrayLike variant converts to a ``jax.Array``."""

        def process_data(data: ArrayLike) -> jax.Array:
            return jnp.asarray(data)

        # JAX array
        jax_array = jnp.array([1, 2, 3])
        result1 = process_data(jax_array)
        self.assertIsInstance(result1, jax.Array)

        # NumPy array
        numpy_array = np.array([1, 2, 3])
        result2 = process_data(numpy_array)
        self.assertIsInstance(result2, jax.Array)

        # Python scalars
        result3 = process_data(42)
        self.assertIsInstance(result3, jax.Array)
        self.assertEqual(result3.shape, ())

        result4 = process_data(3.14)
        self.assertIsInstance(result4, jax.Array)

        result5 = process_data(True)
        self.assertIsInstance(result5, jax.Array)

        result6 = process_data(1 + 2j)
        self.assertIsInstance(result6, jax.Array)

        # NumPy scalars
        result7 = process_data(np.float32(2.5))
        self.assertIsInstance(result7, jax.Array)

        result8 = process_data(np.bool_(False))
        self.assertIsInstance(result8, jax.Array)

    def test_arraylike_brainunit_mantissa(self):
        """Test that the mantissa of a BrainUnit quantity converts to a ``jax.Array``.

        This runs as its own test so that an unavailable BrainUnit API is
        reported as an explicit skip, instead of being swallowed by a bare
        ``pass`` inside another test.
        """
        try:
            mantissa = (1.5 * u.second).mantissa
        except (AttributeError, TypeError) as err:
            self.skipTest(f"BrainUnit quantity mantissa unavailable: {err}")
        # Conversion itself is outside the try: a failure here is a real bug.
        result = jnp.asarray(mantissa)
        self.assertIsInstance(result, jax.Array)

    def test_dtype_variants(self):
        """Test DType and DTypeLike variants."""

        def cast_array(array: ArrayLike, dtype: DTypeLike) -> jax.Array:
            return jnp.asarray(array, dtype=dtype)

        test_data = [1, 2, 3]

        # String dtype
        result1 = cast_array(test_data, 'float32')
        self.assertEqual(result1.dtype, jnp.float32)

        # NumPy type
        result2 = cast_array(test_data, np.float32)
        self.assertEqual(result2.dtype, jnp.float32)

        # Python type (concrete width depends on JAX's x64 setting)
        result3 = cast_array(test_data, float)
        self.assertTrue(jnp.issubdtype(result3.dtype, jnp.floating))

        # NumPy dtype object
        result4 = cast_array(test_data, np.dtype('int32'))
        self.assertEqual(result4.dtype, jnp.int32)

    def test_supports_dtype_protocol(self):
        """Test SupportsDType protocol."""

        def get_dtype(obj: SupportsDType) -> DType:
            return obj.dtype

        # Test with JAX arrays
        arr = jnp.array([1.0, 2.0])
        dtype = get_dtype(arr)
        self.assertIsInstance(dtype, np.dtype)

        # Test with NumPy arrays
        np_arr = np.array([1, 2], dtype=np.int64)
        dtype2 = get_dtype(np_arr)
        self.assertEqual(dtype2, np.int64)

    def test_dtype_alias(self):
        """Test DType alias."""

        def create_array(shape: Shape, dtype: DType) -> jax.Array:
            return jnp.zeros(shape, dtype=dtype)

        arr = create_array((3, 4), np.float32)
        self.assertEqual(arr.shape, (3, 4))
        self.assertEqual(arr.dtype, jnp.float32)
|
431
|
+
|
432
|
+
|
433
|
+
class TestPyTreeTypes(unittest.TestCase):
    """Test PyTree type annotations."""

    def test_pytree_basic_usage(self):
        """Test basic PyTree type usage with dict and list trees."""

        def tree_function(tree: PyTree[float]) -> PyTree[float]:
            return jax.tree_util.tree_map(lambda x: x * 2, tree)

        # Dict-structured tree
        tree1 = {"a": 1.0, "b": 2.0}
        result1 = tree_function(tree1)
        self.assertEqual(set(result1), {"a", "b"})
        self.assertAlmostEqual(result1["a"], 2.0)
        self.assertAlmostEqual(result1["b"], 4.0)

        # List-structured tree
        tree2 = [1.0, 2.0, 3.0]
        result2 = tree_function(tree2)
        expected = [2.0, 4.0, 6.0]
        # zip() stops at the shorter input, so check the length explicitly;
        # otherwise a truncated result would pass unnoticed.
        self.assertEqual(len(result2), len(expected))
        for actual, expect in zip(result2, expected):
            self.assertAlmostEqual(actual, expect)

    def test_pytree_with_structure(self):
        """Test PyTree with structure annotations."""

        def structured_function(tree: PyTree[float, "T"]) -> PyTree[float, "T"]:
            return jax.tree_util.tree_map(lambda x: x + 1, tree)

        tree = {"weights": 1.0, "bias": 2.0}
        result = structured_function(tree)
        self.assertAlmostEqual(result["weights"], 2.0)
        self.assertAlmostEqual(result["bias"], 3.0)

    def test_pytree_instantiation_error(self):
        """Test that PyTree cannot be instantiated."""
        with self.assertRaises(RuntimeError):
            PyTree()

    def test_pytree_subscripting(self):
        """Test PyTree subscripting behavior."""
        # Single type parameter
        pytree_type = PyTree[float]
        self.assertTrue(hasattr(pytree_type, 'leaftype'))
        self.assertEqual(pytree_type.leaftype, float)

        # Type and structure parameters
        pytree_structured = PyTree[int, "T"]
        self.assertTrue(hasattr(pytree_structured, 'leaftype'))
        self.assertTrue(hasattr(pytree_structured, 'structure'))
        self.assertEqual(pytree_structured.leaftype, int)
        self.assertEqual(pytree_structured.structure, "T")

    def test_pytree_structure_validation(self):
        """Test PyTree structure-name validation."""
        # Valid structure names must be accepted without raising.
        for structure in ["T", "S T", "... T", "T ...", "foo bar"]:
            with self.subTest(structure=structure):
                PyTree[float, structure]

        # Invalid structure names: empty, non-identifier, leading digit.
        for structure in ["", "invalid-identifier", "123abc"]:
            with self.subTest(structure=structure):
                with self.assertRaises(ValueError):
                    PyTree[float, structure]

    def test_pytree_tuple_length_validation(self):
        """Test PyTree tuple parameter validation."""
        # Valid 2-tuple
        PyTree[float, "T"]

        # Invalid tuple lengths
        with self.assertRaises(ValueError):
            PyTree[float, "T", "extra"]  # 3-tuple

        with self.assertRaises(ValueError):
            PyTree[float,]  # trailing comma passes the 1-tuple (float,)
|
513
|
+
|
514
|
+
|
515
|
+
class TestRandomTypes(unittest.TestCase):
    """Exercise the ``SeedOrKey`` alias with integer seeds and key arrays."""

    @staticmethod
    def _resolve_key(seed_or_key: SeedOrKey):
        """Turn a Python ``int`` seed into a PRNG key; pass anything else through."""
        if isinstance(seed_or_key, int):
            return jr.PRNGKey(seed_or_key)
        return seed_or_key

    def test_seed_or_key_variants(self):
        """Integer seeds, JAX keys and raw NumPy key arrays all drive sampling."""
        cases = (
            (42, (3, 4)),
            (jr.PRNGKey(123), (5,)),
            (np.array([1, 2], dtype=np.uint32), (2, 2)),
        )
        for seed_or_key, shape in cases:
            sample = jr.normal(self._resolve_key(seed_or_key), shape)
            self.assertEqual(sample.shape, shape)

    def test_reproducibility_with_seeds(self):
        """Drawing twice from the same seed or key yields identical samples."""

        def draw(seed_or_key: SeedOrKey) -> jax.Array:
            return jr.normal(self._resolve_key(seed_or_key), (5,))

        np.testing.assert_array_equal(draw(42), draw(42))

        shared_key = jr.PRNGKey(999)
        np.testing.assert_array_equal(draw(shared_key), draw(shared_key))
|
560
|
+
|
561
|
+
|
562
|
+
class TestUtilityTypes(unittest.TestCase):
    """Check the ``Missing`` sentinel, the exported type variables and ``_Array``."""

    def test_missing_sentinel(self):
        """A ``Missing`` instance distinguishes 'argument omitted' from ``None``."""
        sentinel = Missing()

        def describe(value: Union[int, None, Missing] = sentinel):
            if value is sentinel:
                return "no_value"
            if value is None:
                return "explicit_none"
            return f"value_{value}"

        call_patterns = [
            ((), "no_value"),
            ((None,), "explicit_none"),
            ((42,), "value_42"),
        ]
        for args, expected in call_patterns:
            self.assertEqual(describe(*args), expected)

        # Every call to Missing() produces a separate object.
        first, second = Missing(), Missing()
        self.assertIsNot(first, second)

    def test_type_variables(self):
        """``K``, ``_T`` and ``_Annotation`` are exported ``TypeVar`` objects."""
        from typing import TypeVar

        for type_var in (K, _T, _Annotation):
            self.assertIsNotNone(type_var)
            self.assertIsInstance(type_var, TypeVar)

    def test_internal_array_type(self):
        """``_Array`` exists, reports ``builtins`` as its module and is subscriptable."""
        self.assertIsNotNone(_Array)
        self.assertEqual(_Array.__module__, "builtins")
        self.assertIsNotNone(_Array[str])
|
610
|
+
|
611
|
+
|
612
|
+
class TestRealWorldUsagePattterns(unittest.TestCase):
|
613
|
+
"""Test real-world usage patterns and integration."""
|
614
|
+
|
615
|
+
def test_neural_network_typing(self):
    """A dense layer annotated with named shapes yields the expected output shape."""

    def dense(
        inputs: Array["batch, in_features"],
        kernel: Array["out_features, in_features"],
        offset: Array["out_features"],
    ) -> Array["batch, out_features"]:
        return inputs @ kernel.T + offset

    batch_size, in_features, out_features = 32, 128, 64
    x_key, w_key, b_key = jr.split(jr.PRNGKey(42), 3)

    outputs = dense(
        jr.normal(x_key, (batch_size, in_features)),
        jr.normal(w_key, (out_features, in_features)),
        jr.normal(b_key, (out_features,)),
    )
    self.assertEqual(outputs.shape, (batch_size, out_features))
|
635
|
+
|
636
|
+
def test_pytree_parameter_filtering(self):
    """Test that a PyTree-annotated transform preserves the parameter tree.

    ``extract_weights`` is a stand-in that maps the identity over the tree.
    The test checks the full nested structure and every leaf, not only the
    top-level keys, so a transform that dropped or altered leaves would fail.
    """

    def extract_weights(params: PyTree[ArrayLike]) -> PyTree[ArrayLike]:
        # Mock filtering - in real code this would select leaves by path.
        return jax.tree_util.tree_map(lambda x: x, params)

    # Test with parameter structure
    params = {
        "layer1": {"weight": jnp.ones((10, 5)), "bias": jnp.zeros(10)},
        "layer2": {"weight": jnp.ones((5, 3)), "bias": jnp.zeros(5)}
    }

    result = extract_weights(params)
    self.assertIsInstance(result, dict)
    self.assertIn("layer1", result)
    self.assertIn("layer2", result)

    # The whole nested structure and every leaf must survive unchanged.
    self.assertEqual(
        jax.tree_util.tree_structure(result),
        jax.tree_util.tree_structure(params),
    )
    for got, want in zip(jax.tree_util.tree_leaves(result),
                         jax.tree_util.tree_leaves(params)):
        np.testing.assert_array_equal(got, want)
|
653
|
+
|
654
|
+
def test_mixed_type_operations(self):
|
655
|
+
"""Test operations mixing different typed inputs."""
|
656
|
+
|
657
|
+
def process_mixed_data(
|
658
|
+
arrays: ArrayLike,
|
659
|
+
shape: Shape,
|
660
|
+
dtype: DTypeLike,
|
661
|
+
seed: SeedOrKey
|
662
|
+
) -> jax.Array:
|
663
|
+
# Convert inputs
|
664
|
+
data = jnp.asarray(arrays, dtype=dtype)
|
665
|
+
key = jr.PRNGKey(seed) if isinstance(seed, int) else seed
|
666
|
+
|
667
|
+
# Generate noise and add to data
|
668
|
+
noise = jr.normal(key, shape) * 0.1
|
669
|
+
return data.reshape(shape) + noise
|
670
|
+
|
671
|
+
# Test with mixed inputs
|
672
|
+
result = process_mixed_data(
|
673
|
+
arrays=[1, 2, 3, 4],
|
674
|
+
shape=(2, 2),
|
675
|
+
dtype='float32',
|
676
|
+
seed=42
|
677
|
+
)
|
678
|
+
|
679
|
+
self.assertEqual(result.shape, (2, 2))
|
680
|
+
self.assertEqual(result.dtype, jnp.float32)
|
681
|
+
|
682
|
+
def test_scientific_computing_pattern(self):
|
683
|
+
"""Test scientific computing usage patterns."""
|
684
|
+
|
685
|
+
def numerical_integration(
|
686
|
+
func: callable,
|
687
|
+
bounds: ArrayLike,
|
688
|
+
n_points: Size,
|
689
|
+
dtype: DTypeLike = jnp.float32 # Use float32 for JAX compatibility
|
690
|
+
) -> jax.Array:
|
691
|
+
# Mock numerical integration
|
692
|
+
x = jnp.linspace(bounds[0], bounds[1], n_points, dtype=dtype)
|
693
|
+
y = jax.vmap(func)(x)
|
694
|
+
dx = (bounds[1] - bounds[0]) / n_points
|
695
|
+
return jnp.sum(y) * dx
|
696
|
+
|
697
|
+
# Test with simple bounds (skip units for simplicity)
|
698
|
+
bounds_array = jnp.array([0.0, 1.0])
|
699
|
+
|
700
|
+
result = numerical_integration(
|
701
|
+
lambda t: t ** 2,
|
702
|
+
bounds_array,
|
703
|
+
1000,
|
704
|
+
jnp.float32
|
705
|
+
)
|
706
|
+
|
707
|
+
self.assertIsInstance(result, jax.Array)
|
708
|
+
self.assertEqual(result.dtype, jnp.float32)
|
709
|
+
|
710
|
+
|
711
|
+
class TestTypeHintCompatibility(unittest.TestCase):
    """Checks that the exported types cooperate with Python's ``typing`` machinery."""

    def test_get_type_hints(self):
        """Hints on a function annotated with the project types are retrievable."""

        def annotated_function(
            arr: ArrayLike,
            shape: Shape,
            dtype: DTypeLike
        ) -> jax.Array:
            return jnp.zeros(shape, dtype=dtype)

        hints = get_type_hints(annotated_function)

        # Every parameter and the return annotation must appear in the mapping.
        for hinted_name in ('arr', 'shape', 'dtype', 'return'):
            self.assertIn(hinted_name, hints)

    def test_isinstance_checks(self):
        """isinstance checks against ``Key`` plus ``dtype`` attribute presence."""
        # Strings, ints and floats are all accepted as ``Key`` instances.
        for sample_key in ("string", 42, 3.14):
            self.assertIsInstance(sample_key, Key)

        # Both a JAX array and a NumPy array expose a ``dtype`` attribute ...
        jax_sample = jnp.array([1, 2, 3])
        self.assertTrue(hasattr(jax_sample, 'dtype'))

        numpy_sample = np.array([1, 2, 3])
        self.assertTrue(hasattr(numpy_sample, 'dtype'))

        # ... whereas a plain string does not.
        self.assertFalse(hasattr("string", 'dtype'))

    def test_module_imports(self):
        """Every listed public name can be imported from ``brainstate.typing``."""
        from brainstate.typing import (
            Key, PathParts, Predicate, Filter, Array, ArrayLike,
            Shape, Size, Axes, PyTree, SeedOrKey, DType, DTypeLike,
            Missing
        )

        imported_objects = (
            Key, PathParts, Predicate, Filter, Array, ArrayLike,
            Shape, Size, Axes, PyTree, SeedOrKey, DType, DTypeLike,
            Missing,
        )
        for imported_obj in imported_objects:
            self.assertIsNotNone(imported_obj)

    def test_documentation_strings(self):
        """Selected public types carry a non-trivial docstring."""
        for documented in (Key, Array, PyTree, Missing, SupportsDType):
            doc_text = documented.__doc__
            self.assertIsNotNone(doc_text)
            # Longer than 10 characters counts as substantial documentation.
            self.assertGreater(len(doc_text), 10)
|
776
|
+
|
777
|
+
|
778
|
+
if __name__ == "__main__":
    # verbosity=2 makes the runner print each test's name and outcome.
    unittest.main(verbosity=2)
|