brainstate 0.1.0.post20250315__py2.py3-none-any.whl → 0.1.0.post20250322__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -952,6 +952,9 @@ def _vmap_new_states_transform(
952
952
  out_states: Dict[int, Dict] | Any | None = None,
953
953
  ):
954
954
  # TODO: How about nested call ``vmap_new_states``?
955
+ if isinstance(axis_size, int) and axis_size <= 0:
956
+ raise ValueError(f"axis_size must be greater than 0, got {axis_size}.")
957
+
955
958
 
956
959
  @vmap(
957
960
  in_axes=in_axes,
@@ -315,7 +315,6 @@ class TestVmap(unittest.TestCase):
315
315
  )
316
316
 
317
317
 
318
-
319
318
  class TestMap(unittest.TestCase):
320
319
  def test_map(self):
321
320
  for dim in [(10,), (10, 10), (10, 10, 10)]:
@@ -399,3 +398,38 @@ class TestRemoveAxis:
399
398
  complex_array = jnp.array([[1 + 1j, 2 + 2j], [3 + 3j, 4 + 4j]])
400
399
  complex_result = _remove_axis(complex_array, 0)
401
400
  assert jnp.allclose(complex_result, jnp.array([1 + 1j, 2 + 2j]))
401
+
402
+
403
class TestVMAPNewStatesEdgeCases(unittest.TestCase):
    """Edge cases for ``bst.augment.vmap_new_states``."""

    def _assert_axis_size_rejected(self, axis_size):
        """Assert that a non-positive ``axis_size`` triggers the explicit validation error.

        Both the decoration and the call are inside the ``assertRaises`` context.
        The ``axis_size`` check sits at the top of ``_vmap_new_states_transform``,
        so depending on how ``vmap_new_states`` dispatches, it may fire when the
        function is decorated rather than when it is called.

        Matching the message ties the failure to that validation instead of to
        an unrelated ``ValueError`` raised further down inside ``vmap``.
        """
        foo = bst.nn.LIF(3)
        with self.assertRaisesRegex(ValueError, r"axis_size must be greater than 0"):
            @bst.augment.vmap_new_states(state_tag='new1', axis_size=axis_size)
            def faulty_init():
                foo.init_state()

            faulty_init()

    def test_axis_size_zero(self):
        # An axis_size of 0 must be rejected.
        self._assert_axis_size_rejected(0)

    def test_axis_size_negative(self):
        # A negative axis_size must be rejected.
        self._assert_axis_size_rejected(-3)

    def test_incompatible_shapes(self):
        # NOTE(review): despite its name, nothing here is shape-incompatible and
        # nothing is asserted. It only checks that creating a brand-new ``State``
        # of shape (3,) inside a ``vmap_new_states`` function with
        # ``axis_size=5`` runs without raising.
        # TODO: decide the intended behavior for a genuinely mismatched state
        # and assert it.
        foo = bst.nn.LIF(3)

        @bst.augment.vmap_new_states(state_tag='new1', axis_size=5)
        def faulty_init():
            foo.c = bst.State(jnp.arange(3))

        faulty_init()
@@ -250,9 +250,9 @@ def vmap_call_all_functions(
250
250
  @set_module_as('brainstate.nn')
251
251
  def init_all_states(
252
252
  target: T,
253
- init_args: Tuple[Any, ...] | Any = (),
254
- init_kwargs: Dict[str, Any] | None = None,
253
+ *init_args,
255
254
  node_to_exclude: Filter = None,
255
+ **init_kwargs,
256
256
  ) -> T:
257
257
  """
258
258
  Initialize all states for the given target module and its submodules.
@@ -34,3 +34,14 @@ class Test_vmap_init_all_states:
34
34
  print(gru)
35
35
 
36
36
  init()
37
+
38
+
39
class Test_init_all_states:
    """Exercise ``bst.nn.init_all_states`` with keyword-only initialization arguments."""

    def test_init_all_states(self):
        # ``batch_size`` is not a named parameter of ``init_all_states``.
        # It is collected into ``**init_kwargs`` (presumably forwarded to the
        # submodules' state initializers -- confirm against the implementation).
        cell = bst.nn.GRUCell(1, 2)
        bst.nn.init_all_states(cell, batch_size=10)
        print(cell)
44
+
45
+
46
+
47
+
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: brainstate
3
- Version: 0.1.0.post20250315
3
+ Version: 0.1.0.post20250322
4
4
  Summary: A ``State``-based Transformation System for Program Compilation and Augmentation.
5
5
  Home-page: https://github.com/chaobrain/brainstate
6
6
  Author: BrainState Developers
@@ -14,8 +14,8 @@ brainstate/augment/_autograd.py,sha256=hfDoa2HbkRn-InOS0yOcb6gEZ2DLNqtWA133P8-hv
14
14
  brainstate/augment/_autograd_test.py,sha256=2wCC8aUcDp2IHgF7wr1GK5HwWfELXni5PpA-082azuU,44058
15
15
  brainstate/augment/_eval_shape.py,sha256=jgsS197Nizehr9A2nGaQPE7NuNujhFhmR3J96hTicX8,3890
16
16
  brainstate/augment/_eval_shape_test.py,sha256=LFOJx7CWltmRLXdGY175UebLwtEMz2CzJ_gLqMZsJTw,1393
17
- brainstate/augment/_mapping.py,sha256=BPwpD7jX4xRNl4BdAsKGoF45MKbmEF9Lyyp11pJucIg,43356
18
- brainstate/augment/_mapping_test.py,sha256=-4HJXmJw_6SD9dQnHTBjgYVuq6VTVjz0xpc9v2CJVNw,13414
17
+ brainstate/augment/_mapping.py,sha256=ru4byvHCtEvmSBTC-DckgKqyJEbKxstPdJcz3uE5AZ8,43494
18
+ brainstate/augment/_mapping_test.py,sha256=Ax9-NjnCHPrvO_fPN24mOtl0NZugWg3AcNyxUYnHS1E,14709
19
19
  brainstate/augment/_random.py,sha256=ikRzNoDDE2BkARajDsBhNlngCUrghzGSZUDmEGvVors,5386
20
20
  brainstate/compile/__init__.py,sha256=fQtG316MLkeeu1Ssp54Kghw1PwbGK5gNq9yRVJu0wjA,1474
21
21
  brainstate/compile/_ad_checkpoint.py,sha256=3wv-f89oo94XeWwRV5LcRot0Nz7xTk5_PdjEDyUMsoo,9394
@@ -54,8 +54,8 @@ brainstate/init/_random_inits_test.py,sha256=lBL2RQdBSZ88Zqz4IMdbHJMvDi7ooZq6caC
54
54
  brainstate/init/_regular_inits.py,sha256=DmVMajugfyYFNUMzgFdDKMvbBu9hMWxkfDd-50uhoLg,3187
55
55
  brainstate/init/_regular_inits_test.py,sha256=tJl4aOkclllJIfKzJTbc0cfYCw2SoBsx8_G123RnqbU,1842
56
56
  brainstate/nn/__init__.py,sha256=ar1hDUYbSO6oadMpbuS9FWZvZB_iyFzM8CwMK-RNDzM,1823
57
- brainstate/nn/_collective_ops.py,sha256=NI9BT-908TbIlXLMjbWsPyI5YLZD_cCkSKGeOY-qO60,17512
58
- brainstate/nn/_collective_ops_test.py,sha256=yW7NNYsGFglFRFkqVlpGSY6WLnU-h8GlK6wCmG5jtRc,1189
57
+ brainstate/nn/_collective_ops.py,sha256=Hb8ruvqlhSA-XWJ66ReItyaIhxyWlSKGFZ9-EYMQ-mk,17457
58
+ brainstate/nn/_collective_ops_test.py,sha256=nloqrlf6M7H-mgvHmIARrKzMotp8khxEuYSMPvXM5J0,1375
59
59
  brainstate/nn/_common.py,sha256=XQw0i0sH3Y_qUwHSMC7G9VQnDj-RuuTh1Ul-xRIPxxc,7136
60
60
  brainstate/nn/_exp_euler.py,sha256=s-Z_cT_oYvCvE-OaXuUidIxQs3KOy1pzkx1lwtfPo00,3529
61
61
  brainstate/nn/_exp_euler_test.py,sha256=kvPf009DMYtla2uedKVKrPTHDyMTBepjlfsk5vDHqhI,1240
@@ -121,8 +121,8 @@ brainstate/util/_pretty_table.py,sha256=NM_6VAW6oL9jojsK0-RkQGHnDzLy_fn_hgzl5R8o
121
121
  brainstate/util/_scaling.py,sha256=pc_eM_SZVwkY65I4tJh1ODiHNCoEhsfFXl2zBK0PLAg,7562
122
122
  brainstate/util/_struct.py,sha256=F5GfFURITAIYTwf17_xypkZU1wvoL4dUCviPnr_eCtw,17515
123
123
  brainstate/util/filter.py,sha256=Zw0H42NwAi2P7dBr3ISv2VpkB5jqoWnV4Kpd61gq66o,14126
124
- brainstate-0.1.0.post20250315.dist-info/LICENSE,sha256=VZe9u1jgUL2eCY6ZPOYgdb8KCblCHt8ECdbtJid6e1s,11550
125
- brainstate-0.1.0.post20250315.dist-info/METADATA,sha256=8gOEdv6PiXBLr_gAvx70Yik7G7XxidMVPPLOLx3ndPc,3689
126
- brainstate-0.1.0.post20250315.dist-info/WHEEL,sha256=bb2Ot9scclHKMOLDEHY6B2sicWOgugjFKaJsT7vwMQo,110
127
- brainstate-0.1.0.post20250315.dist-info/top_level.txt,sha256=eQbGgKn0ptx7FDWuua0V0wr4K1VHi2iOUCYo3fUQBRA,11
128
- brainstate-0.1.0.post20250315.dist-info/RECORD,,
124
+ brainstate-0.1.0.post20250322.dist-info/LICENSE,sha256=VZe9u1jgUL2eCY6ZPOYgdb8KCblCHt8ECdbtJid6e1s,11550
125
+ brainstate-0.1.0.post20250322.dist-info/METADATA,sha256=ixPeEJ2tz1Dti9D7RFHS-bDLCj1a5VtVnbo7GS82xaE,3689
126
+ brainstate-0.1.0.post20250322.dist-info/WHEEL,sha256=bb2Ot9scclHKMOLDEHY6B2sicWOgugjFKaJsT7vwMQo,110
127
+ brainstate-0.1.0.post20250322.dist-info/top_level.txt,sha256=eQbGgKn0ptx7FDWuua0V0wr4K1VHi2iOUCYo3fUQBRA,11
128
+ brainstate-0.1.0.post20250322.dist-info/RECORD,,