brainstate 0.1.0.post20250315__py2.py3-none-any.whl → 0.1.0.post20250325__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/augment/_mapping.py +3 -0
- brainstate/augment/_mapping_test.py +35 -1
- brainstate/nn/_collective_ops.py +7 -7
- brainstate/nn/_collective_ops_test.py +11 -0
- {brainstate-0.1.0.post20250315.dist-info → brainstate-0.1.0.post20250325.dist-info}/METADATA +1 -1
- {brainstate-0.1.0.post20250315.dist-info → brainstate-0.1.0.post20250325.dist-info}/RECORD +9 -9
- {brainstate-0.1.0.post20250315.dist-info → brainstate-0.1.0.post20250325.dist-info}/LICENSE +0 -0
- {brainstate-0.1.0.post20250315.dist-info → brainstate-0.1.0.post20250325.dist-info}/WHEEL +0 -0
- {brainstate-0.1.0.post20250315.dist-info → brainstate-0.1.0.post20250325.dist-info}/top_level.txt +0 -0
brainstate/augment/_mapping.py
CHANGED
@@ -952,6 +952,9 @@ def _vmap_new_states_transform(
|
|
952
952
|
out_states: Dict[int, Dict] | Any | None = None,
|
953
953
|
):
|
954
954
|
# TODO: How about nested call ``vmap_new_states``?
|
955
|
+
if isinstance(axis_size, int) and axis_size <= 0:
|
956
|
+
raise ValueError(f"axis_size must be greater than 0, got {axis_size}.")
|
957
|
+
|
955
958
|
|
956
959
|
@vmap(
|
957
960
|
in_axes=in_axes,
|
brainstate/augment/_mapping_test.py
CHANGED
@@ -315,7 +315,6 @@ class TestVmap(unittest.TestCase):
|
|
315
315
|
)
|
316
316
|
|
317
317
|
|
318
|
-
|
319
318
|
class TestMap(unittest.TestCase):
|
320
319
|
def test_map(self):
|
321
320
|
for dim in [(10,), (10, 10), (10, 10, 10)]:
|
@@ -399,3 +398,38 @@ class TestRemoveAxis:
|
|
399
398
|
complex_array = jnp.array([[1 + 1j, 2 + 2j], [3 + 3j, 4 + 4j]])
|
400
399
|
complex_result = _remove_axis(complex_array, 0)
|
401
400
|
assert jnp.allclose(complex_result, jnp.array([1 + 1j, 2 + 2j]))
|
401
|
+
|
402
|
+
|
403
|
+
class TestVMAPNewStatesEdgeCases(unittest.TestCase):
|
404
|
+
|
405
|
+
def test_axis_size_zero(self):
|
406
|
+
foo = brainstate.nn.LIF(3)
|
407
|
+
# Testing that axis_size of 0 raises an error.
|
408
|
+
with self.assertRaises(ValueError):
|
409
|
+
@bst.augment.vmap_new_states(state_tag='new1', axis_size=0)
|
410
|
+
def faulty_init():
|
411
|
+
foo.init_state()
|
412
|
+
|
413
|
+
# Call the decorated function to trigger validation
|
414
|
+
faulty_init()
|
415
|
+
|
416
|
+
def test_axis_size_negative(self):
|
417
|
+
foo = brainstate.nn.LIF(3)
|
418
|
+
# Testing that a negative axis_size raises an error.
|
419
|
+
with self.assertRaises(ValueError):
|
420
|
+
@bst.augment.vmap_new_states(state_tag='new1', axis_size=-3)
|
421
|
+
def faulty_init():
|
422
|
+
foo.init_state()
|
423
|
+
|
424
|
+
faulty_init()
|
425
|
+
|
426
|
+
def test_incompatible_shapes(self):
|
427
|
+
foo = brainstate.nn.LIF(3)
|
428
|
+
# Simulate an incompatible shapes scenario:
|
429
|
+
# We intentionally assign a state with a different shape than expected.
|
430
|
+
@bst.augment.vmap_new_states(state_tag='new1', axis_size=5)
|
431
|
+
def faulty_init():
|
432
|
+
# Modify state to produce an incompatible shape
|
433
|
+
foo.c = bst.State(jnp.arange(3)) # Original expected shape is (4,)
|
434
|
+
|
435
|
+
faulty_init()
|
brainstate/nn/_collective_ops.py
CHANGED
@@ -16,9 +16,9 @@
|
|
16
16
|
from __future__ import annotations
|
17
17
|
|
18
18
|
from collections import namedtuple
|
19
|
+
from typing import Callable, TypeVar, Tuple, Any, Dict
|
19
20
|
|
20
21
|
import jax
|
21
|
-
from typing import Callable, TypeVar, Tuple, Any, Dict
|
22
22
|
|
23
23
|
from brainstate._state import catch_new_states
|
24
24
|
from brainstate._utils import set_module_as
|
@@ -250,9 +250,9 @@ def vmap_call_all_functions(
|
|
250
250
|
@set_module_as('brainstate.nn')
|
251
251
|
def init_all_states(
|
252
252
|
target: T,
|
253
|
-
init_args
|
254
|
-
init_kwargs: Dict[str, Any] | None = None,
|
253
|
+
*init_args,
|
255
254
|
node_to_exclude: Filter = None,
|
255
|
+
**init_kwargs,
|
256
256
|
) -> T:
|
257
257
|
"""
|
258
258
|
Initialize all states for the given target module and its submodules.
|
@@ -289,12 +289,12 @@ def init_all_states(
|
|
289
289
|
@set_module_as('brainstate.nn')
|
290
290
|
def vmap_init_all_states(
|
291
291
|
target: T,
|
292
|
-
init_args: Tuple[Any, ...] | Any
|
293
|
-
init_kwargs: Dict[str, Any] | None = None,
|
292
|
+
*init_args: Tuple[Any, ...] | Any,
|
294
293
|
axis_size: int = None,
|
295
294
|
node_to_exclude: Filter = None,
|
296
295
|
state_to_exclude: Filter = None,
|
297
296
|
state_tag: str | None = None,
|
297
|
+
**init_kwargs: Dict[str, Any] | None
|
298
298
|
) -> T:
|
299
299
|
"""
|
300
300
|
Initialize all vmap states for the given target module.
|
@@ -342,8 +342,8 @@ def vmap_init_all_states(
|
|
342
342
|
def init_fn():
|
343
343
|
init_all_states(
|
344
344
|
target,
|
345
|
-
init_args
|
346
|
-
init_kwargs
|
345
|
+
*init_args,
|
346
|
+
**init_kwargs,
|
347
347
|
node_to_exclude=node_to_exclude,
|
348
348
|
)
|
349
349
|
return
|
{brainstate-0.1.0.post20250315.dist-info → brainstate-0.1.0.post20250325.dist-info}/METADATA
RENAMED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: brainstate
|
3
|
-
Version: 0.1.0.post20250315
|
3
|
+
Version: 0.1.0.post20250325
|
4
4
|
Summary: A ``State``-based Transformation System for Program Compilation and Augmentation.
|
5
5
|
Home-page: https://github.com/chaobrain/brainstate
|
6
6
|
Author: BrainState Developers
|
{brainstate-0.1.0.post20250315.dist-info → brainstate-0.1.0.post20250325.dist-info}/RECORD
RENAMED
@@ -14,8 +14,8 @@ brainstate/augment/_autograd.py,sha256=hfDoa2HbkRn-InOS0yOcb6gEZ2DLNqtWA133P8-hv
|
|
14
14
|
brainstate/augment/_autograd_test.py,sha256=2wCC8aUcDp2IHgF7wr1GK5HwWfELXni5PpA-082azuU,44058
|
15
15
|
brainstate/augment/_eval_shape.py,sha256=jgsS197Nizehr9A2nGaQPE7NuNujhFhmR3J96hTicX8,3890
|
16
16
|
brainstate/augment/_eval_shape_test.py,sha256=LFOJx7CWltmRLXdGY175UebLwtEMz2CzJ_gLqMZsJTw,1393
|
17
|
-
brainstate/augment/_mapping.py,sha256=
|
18
|
-
brainstate/augment/_mapping_test.py,sha256
|
17
|
+
brainstate/augment/_mapping.py,sha256=ru4byvHCtEvmSBTC-DckgKqyJEbKxstPdJcz3uE5AZ8,43494
|
18
|
+
brainstate/augment/_mapping_test.py,sha256=Ax9-NjnCHPrvO_fPN24mOtl0NZugWg3AcNyxUYnHS1E,14709
|
19
19
|
brainstate/augment/_random.py,sha256=ikRzNoDDE2BkARajDsBhNlngCUrghzGSZUDmEGvVors,5386
|
20
20
|
brainstate/compile/__init__.py,sha256=fQtG316MLkeeu1Ssp54Kghw1PwbGK5gNq9yRVJu0wjA,1474
|
21
21
|
brainstate/compile/_ad_checkpoint.py,sha256=3wv-f89oo94XeWwRV5LcRot0Nz7xTk5_PdjEDyUMsoo,9394
|
@@ -54,8 +54,8 @@ brainstate/init/_random_inits_test.py,sha256=lBL2RQdBSZ88Zqz4IMdbHJMvDi7ooZq6caC
|
|
54
54
|
brainstate/init/_regular_inits.py,sha256=DmVMajugfyYFNUMzgFdDKMvbBu9hMWxkfDd-50uhoLg,3187
|
55
55
|
brainstate/init/_regular_inits_test.py,sha256=tJl4aOkclllJIfKzJTbc0cfYCw2SoBsx8_G123RnqbU,1842
|
56
56
|
brainstate/nn/__init__.py,sha256=ar1hDUYbSO6oadMpbuS9FWZvZB_iyFzM8CwMK-RNDzM,1823
|
57
|
-
brainstate/nn/_collective_ops.py,sha256=
|
58
|
-
brainstate/nn/_collective_ops_test.py,sha256=
|
57
|
+
brainstate/nn/_collective_ops.py,sha256=v5deEfjCWylCk4bV0b3mLHjFUJ0L9YiGhJ7D_RuZwBE,17428
|
58
|
+
brainstate/nn/_collective_ops_test.py,sha256=nloqrlf6M7H-mgvHmIARrKzMotp8khxEuYSMPvXM5J0,1375
|
59
59
|
brainstate/nn/_common.py,sha256=XQw0i0sH3Y_qUwHSMC7G9VQnDj-RuuTh1Ul-xRIPxxc,7136
|
60
60
|
brainstate/nn/_exp_euler.py,sha256=s-Z_cT_oYvCvE-OaXuUidIxQs3KOy1pzkx1lwtfPo00,3529
|
61
61
|
brainstate/nn/_exp_euler_test.py,sha256=kvPf009DMYtla2uedKVKrPTHDyMTBepjlfsk5vDHqhI,1240
|
@@ -121,8 +121,8 @@ brainstate/util/_pretty_table.py,sha256=NM_6VAW6oL9jojsK0-RkQGHnDzLy_fn_hgzl5R8o
|
|
121
121
|
brainstate/util/_scaling.py,sha256=pc_eM_SZVwkY65I4tJh1ODiHNCoEhsfFXl2zBK0PLAg,7562
|
122
122
|
brainstate/util/_struct.py,sha256=F5GfFURITAIYTwf17_xypkZU1wvoL4dUCviPnr_eCtw,17515
|
123
123
|
brainstate/util/filter.py,sha256=Zw0H42NwAi2P7dBr3ISv2VpkB5jqoWnV4Kpd61gq66o,14126
|
124
|
-
brainstate-0.1.0.
|
125
|
-
brainstate-0.1.0.
|
126
|
-
brainstate-0.1.0.
|
127
|
-
brainstate-0.1.0.
|
128
|
-
brainstate-0.1.0.
|
124
|
+
brainstate-0.1.0.post20250325.dist-info/LICENSE,sha256=VZe9u1jgUL2eCY6ZPOYgdb8KCblCHt8ECdbtJid6e1s,11550
|
125
|
+
brainstate-0.1.0.post20250325.dist-info/METADATA,sha256=Ouz5zN2P7bIG1OYkQRFOaSxWeL4OOjHRJ2cQKaqZc6E,3689
|
126
|
+
brainstate-0.1.0.post20250325.dist-info/WHEEL,sha256=bb2Ot9scclHKMOLDEHY6B2sicWOgugjFKaJsT7vwMQo,110
|
127
|
+
brainstate-0.1.0.post20250325.dist-info/top_level.txt,sha256=eQbGgKn0ptx7FDWuua0V0wr4K1VHi2iOUCYo3fUQBRA,11
|
128
|
+
brainstate-0.1.0.post20250325.dist-info/RECORD,,
|
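{brainstate-0.1.0.post20250315.dist-info → brainstate-0.1.0.post20250325.dist-info}/LICENSE
RENAMED
File without changes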
|
{brainstate-0.1.0.post20250315.dist-info → brainstate-0.1.0.post20250325.dist-info}/WHEEL
RENAMED
File without changes
|
{brainstate-0.1.0.post20250315.dist-info → brainstate-0.1.0.post20250325.dist-info}/top_level.txt
RENAMED
File without changes
|