brainstate 0.1.0.post20250315__py2.py3-none-any.whl → 0.1.0.post20250322__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/augment/_mapping.py +3 -0
- brainstate/augment/_mapping_test.py +35 -1
- brainstate/nn/_collective_ops.py +2 -2
- brainstate/nn/_collective_ops_test.py +11 -0
- {brainstate-0.1.0.post20250315.dist-info → brainstate-0.1.0.post20250322.dist-info}/METADATA +1 -1
- {brainstate-0.1.0.post20250315.dist-info → brainstate-0.1.0.post20250322.dist-info}/RECORD +9 -9
- {brainstate-0.1.0.post20250315.dist-info → brainstate-0.1.0.post20250322.dist-info}/LICENSE +0 -0
- {brainstate-0.1.0.post20250315.dist-info → brainstate-0.1.0.post20250322.dist-info}/WHEEL +0 -0
- {brainstate-0.1.0.post20250315.dist-info → brainstate-0.1.0.post20250322.dist-info}/top_level.txt +0 -0
brainstate/augment/_mapping.py
CHANGED
@@ -952,6 +952,9 @@ def _vmap_new_states_transform(
|
|
952
952
|
out_states: Dict[int, Dict] | Any | None = None,
|
953
953
|
):
|
954
954
|
# TODO: How about nested call ``vmap_new_states``?
|
955
|
+
if isinstance(axis_size, int) and axis_size <= 0:
|
956
|
+
raise ValueError(f"axis_size must be greater than 0, got {axis_size}.")
|
957
|
+
|
955
958
|
|
956
959
|
@vmap(
|
957
960
|
in_axes=in_axes,
|
@@ -315,7 +315,6 @@ class TestVmap(unittest.TestCase):
|
|
315
315
|
)
|
316
316
|
|
317
317
|
|
318
|
-
|
319
318
|
class TestMap(unittest.TestCase):
|
320
319
|
def test_map(self):
|
321
320
|
for dim in [(10,), (10, 10), (10, 10, 10)]:
|
@@ -399,3 +398,38 @@ class TestRemoveAxis:
|
|
399
398
|
complex_array = jnp.array([[1 + 1j, 2 + 2j], [3 + 3j, 4 + 4j]])
|
400
399
|
complex_result = _remove_axis(complex_array, 0)
|
401
400
|
assert jnp.allclose(complex_result, jnp.array([1 + 1j, 2 + 2j]))
|
401
|
+
|
402
|
+
|
403
|
+
class TestVMAPNewStatesEdgeCases(unittest.TestCase):
|
404
|
+
|
405
|
+
def test_axis_size_zero(self):
|
406
|
+
foo = brainstate.nn.LIF(3)
|
407
|
+
# Testing that axis_size of 0 raises an error.
|
408
|
+
with self.assertRaises(ValueError):
|
409
|
+
@bst.augment.vmap_new_states(state_tag='new1', axis_size=0)
|
410
|
+
def faulty_init():
|
411
|
+
foo.init_state()
|
412
|
+
|
413
|
+
# Call the decorated function to trigger validation
|
414
|
+
faulty_init()
|
415
|
+
|
416
|
+
def test_axis_size_negative(self):
|
417
|
+
foo = brainstate.nn.LIF(3)
|
418
|
+
# Testing that a negative axis_size raises an error.
|
419
|
+
with self.assertRaises(ValueError):
|
420
|
+
@bst.augment.vmap_new_states(state_tag='new1', axis_size=-3)
|
421
|
+
def faulty_init():
|
422
|
+
foo.init_state()
|
423
|
+
|
424
|
+
faulty_init()
|
425
|
+
|
426
|
+
def test_incompatible_shapes(self):
|
427
|
+
foo = brainstate.nn.LIF(3)
|
428
|
+
# Simulate an incompatible shapes scenario:
|
429
|
+
# We intentionally assign a state with a different shape than expected.
|
430
|
+
@bst.augment.vmap_new_states(state_tag='new1', axis_size=5)
|
431
|
+
def faulty_init():
|
432
|
+
# Modify state to produce an incompatible shape
|
433
|
+
foo.c = bst.State(jnp.arange(3)) # Original expected shape is (4,)
|
434
|
+
|
435
|
+
faulty_init()
|
brainstate/nn/_collective_ops.py
CHANGED
@@ -250,9 +250,9 @@ def vmap_call_all_functions(
|
|
250
250
|
@set_module_as('brainstate.nn')
|
251
251
|
def init_all_states(
|
252
252
|
target: T,
|
253
|
-
init_args
|
254
|
-
init_kwargs: Dict[str, Any] | None = None,
|
253
|
+
*init_args,
|
255
254
|
node_to_exclude: Filter = None,
|
255
|
+
**init_kwargs,
|
256
256
|
) -> T:
|
257
257
|
"""
|
258
258
|
Initialize all states for the given target module and its submodules.
|
{brainstate-0.1.0.post20250315.dist-info → brainstate-0.1.0.post20250322.dist-info}/METADATA
RENAMED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: brainstate
|
3
|
-
Version: 0.1.0.post20250315
|
3
|
+
Version: 0.1.0.post20250322
|
4
4
|
Summary: A ``State``-based Transformation System for Program Compilation and Augmentation.
|
5
5
|
Home-page: https://github.com/chaobrain/brainstate
|
6
6
|
Author: BrainState Developers
|
@@ -14,8 +14,8 @@ brainstate/augment/_autograd.py,sha256=hfDoa2HbkRn-InOS0yOcb6gEZ2DLNqtWA133P8-hv
|
|
14
14
|
brainstate/augment/_autograd_test.py,sha256=2wCC8aUcDp2IHgF7wr1GK5HwWfELXni5PpA-082azuU,44058
|
15
15
|
brainstate/augment/_eval_shape.py,sha256=jgsS197Nizehr9A2nGaQPE7NuNujhFhmR3J96hTicX8,3890
|
16
16
|
brainstate/augment/_eval_shape_test.py,sha256=LFOJx7CWltmRLXdGY175UebLwtEMz2CzJ_gLqMZsJTw,1393
|
17
|
-
brainstate/augment/_mapping.py,sha256=
|
18
|
-
brainstate/augment/_mapping_test.py,sha256
|
17
|
+
brainstate/augment/_mapping.py,sha256=ru4byvHCtEvmSBTC-DckgKqyJEbKxstPdJcz3uE5AZ8,43494
|
18
|
+
brainstate/augment/_mapping_test.py,sha256=Ax9-NjnCHPrvO_fPN24mOtl0NZugWg3AcNyxUYnHS1E,14709
|
19
19
|
brainstate/augment/_random.py,sha256=ikRzNoDDE2BkARajDsBhNlngCUrghzGSZUDmEGvVors,5386
|
20
20
|
brainstate/compile/__init__.py,sha256=fQtG316MLkeeu1Ssp54Kghw1PwbGK5gNq9yRVJu0wjA,1474
|
21
21
|
brainstate/compile/_ad_checkpoint.py,sha256=3wv-f89oo94XeWwRV5LcRot0Nz7xTk5_PdjEDyUMsoo,9394
|
@@ -54,8 +54,8 @@ brainstate/init/_random_inits_test.py,sha256=lBL2RQdBSZ88Zqz4IMdbHJMvDi7ooZq6caC
|
|
54
54
|
brainstate/init/_regular_inits.py,sha256=DmVMajugfyYFNUMzgFdDKMvbBu9hMWxkfDd-50uhoLg,3187
|
55
55
|
brainstate/init/_regular_inits_test.py,sha256=tJl4aOkclllJIfKzJTbc0cfYCw2SoBsx8_G123RnqbU,1842
|
56
56
|
brainstate/nn/__init__.py,sha256=ar1hDUYbSO6oadMpbuS9FWZvZB_iyFzM8CwMK-RNDzM,1823
|
57
|
-
brainstate/nn/_collective_ops.py,sha256=
|
58
|
-
brainstate/nn/_collective_ops_test.py,sha256=
|
57
|
+
brainstate/nn/_collective_ops.py,sha256=Hb8ruvqlhSA-XWJ66ReItyaIhxyWlSKGFZ9-EYMQ-mk,17457
|
58
|
+
brainstate/nn/_collective_ops_test.py,sha256=nloqrlf6M7H-mgvHmIARrKzMotp8khxEuYSMPvXM5J0,1375
|
59
59
|
brainstate/nn/_common.py,sha256=XQw0i0sH3Y_qUwHSMC7G9VQnDj-RuuTh1Ul-xRIPxxc,7136
|
60
60
|
brainstate/nn/_exp_euler.py,sha256=s-Z_cT_oYvCvE-OaXuUidIxQs3KOy1pzkx1lwtfPo00,3529
|
61
61
|
brainstate/nn/_exp_euler_test.py,sha256=kvPf009DMYtla2uedKVKrPTHDyMTBepjlfsk5vDHqhI,1240
|
@@ -121,8 +121,8 @@ brainstate/util/_pretty_table.py,sha256=NM_6VAW6oL9jojsK0-RkQGHnDzLy_fn_hgzl5R8o
|
|
121
121
|
brainstate/util/_scaling.py,sha256=pc_eM_SZVwkY65I4tJh1ODiHNCoEhsfFXl2zBK0PLAg,7562
|
122
122
|
brainstate/util/_struct.py,sha256=F5GfFURITAIYTwf17_xypkZU1wvoL4dUCviPnr_eCtw,17515
|
123
123
|
brainstate/util/filter.py,sha256=Zw0H42NwAi2P7dBr3ISv2VpkB5jqoWnV4Kpd61gq66o,14126
|
124
|
-
brainstate-0.1.0.
|
125
|
-
brainstate-0.1.0.
|
126
|
-
brainstate-0.1.0.
|
127
|
-
brainstate-0.1.0.
|
128
|
-
brainstate-0.1.0.
|
124
|
+
brainstate-0.1.0.post20250322.dist-info/LICENSE,sha256=VZe9u1jgUL2eCY6ZPOYgdb8KCblCHt8ECdbtJid6e1s,11550
|
125
|
+
brainstate-0.1.0.post20250322.dist-info/METADATA,sha256=ixPeEJ2tz1Dti9D7RFHS-bDLCj1a5VtVnbo7GS82xaE,3689
|
126
|
+
brainstate-0.1.0.post20250322.dist-info/WHEEL,sha256=bb2Ot9scclHKMOLDEHY6B2sicWOgugjFKaJsT7vwMQo,110
|
127
|
+
brainstate-0.1.0.post20250322.dist-info/top_level.txt,sha256=eQbGgKn0ptx7FDWuua0V0wr4K1VHi2iOUCYo3fUQBRA,11
|
128
|
+
brainstate-0.1.0.post20250322.dist-info/RECORD,,
|
File without changes
|
File without changes
|
{brainstate-0.1.0.post20250315.dist-info → brainstate-0.1.0.post20250322.dist-info}/top_level.txt
RENAMED
File without changes
|