brainstate 0.1.0.post20250503__py2.py3-none-any.whl → 0.1.2__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (111) hide show
  1. brainstate/__init__.py +1 -1
  2. brainstate/_compatible_import.py +10 -3
  3. brainstate/_state.py +178 -178
  4. brainstate/_utils.py +0 -1
  5. brainstate/augment/_autograd.py +0 -2
  6. brainstate/augment/_autograd_test.py +132 -133
  7. brainstate/augment/_eval_shape.py +0 -2
  8. brainstate/augment/_eval_shape_test.py +7 -9
  9. brainstate/augment/_mapping.py +2 -3
  10. brainstate/augment/_mapping_test.py +75 -76
  11. brainstate/augment/_random.py +0 -2
  12. brainstate/compile/_ad_checkpoint.py +0 -2
  13. brainstate/compile/_ad_checkpoint_test.py +6 -8
  14. brainstate/compile/_conditions.py +0 -2
  15. brainstate/compile/_conditions_test.py +35 -36
  16. brainstate/compile/_error_if.py +0 -2
  17. brainstate/compile/_error_if_test.py +10 -13
  18. brainstate/compile/_jit.py +9 -8
  19. brainstate/compile/_loop_collect_return.py +0 -2
  20. brainstate/compile/_loop_collect_return_test.py +7 -9
  21. brainstate/compile/_loop_no_collection.py +0 -2
  22. brainstate/compile/_loop_no_collection_test.py +7 -8
  23. brainstate/compile/_make_jaxpr.py +30 -17
  24. brainstate/compile/_make_jaxpr_test.py +20 -20
  25. brainstate/compile/_progress_bar.py +0 -1
  26. brainstate/compile/_unvmap.py +0 -1
  27. brainstate/compile/_util.py +0 -2
  28. brainstate/environ.py +0 -2
  29. brainstate/functional/_activations.py +0 -2
  30. brainstate/functional/_activations_test.py +61 -61
  31. brainstate/functional/_normalization.py +0 -2
  32. brainstate/functional/_others.py +0 -2
  33. brainstate/functional/_spikes.py +0 -1
  34. brainstate/graph/_graph_node.py +1 -3
  35. brainstate/graph/_graph_node_test.py +16 -18
  36. brainstate/graph/_graph_operation.py +4 -2
  37. brainstate/graph/_graph_operation_test.py +154 -156
  38. brainstate/init/_base.py +0 -2
  39. brainstate/init/_generic.py +0 -1
  40. brainstate/init/_random_inits.py +0 -1
  41. brainstate/init/_random_inits_test.py +20 -21
  42. brainstate/init/_regular_inits.py +0 -2
  43. brainstate/init/_regular_inits_test.py +4 -5
  44. brainstate/mixin.py +0 -2
  45. brainstate/nn/_collective_ops.py +0 -3
  46. brainstate/nn/_collective_ops_test.py +8 -8
  47. brainstate/nn/_common.py +0 -2
  48. brainstate/nn/_dyn_impl/_dynamics_neuron.py +0 -2
  49. brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +18 -19
  50. brainstate/nn/_dyn_impl/_dynamics_synapse.py +0 -1
  51. brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +9 -10
  52. brainstate/nn/_dyn_impl/_inputs.py +0 -1
  53. brainstate/nn/_dyn_impl/_rate_rnns.py +0 -1
  54. brainstate/nn/_dyn_impl/_rate_rnns_test.py +6 -7
  55. brainstate/nn/_dyn_impl/_readout.py +0 -1
  56. brainstate/nn/_dyn_impl/_readout_test.py +9 -10
  57. brainstate/nn/_dynamics/_dynamics_base.py +0 -1
  58. brainstate/nn/_dynamics/_dynamics_base_test.py +14 -15
  59. brainstate/nn/_dynamics/_projection_base.py +0 -1
  60. brainstate/nn/_dynamics/_state_delay.py +0 -2
  61. brainstate/nn/_dynamics/_synouts.py +0 -2
  62. brainstate/nn/_dynamics/_synouts_test.py +4 -5
  63. brainstate/nn/_elementwise/_dropout.py +0 -2
  64. brainstate/nn/_elementwise/_dropout_test.py +9 -9
  65. brainstate/nn/_elementwise/_elementwise.py +0 -2
  66. brainstate/nn/_elementwise/_elementwise_test.py +57 -59
  67. brainstate/nn/_event/_fixedprob_mv.py +0 -1
  68. brainstate/nn/_event/_fixedprob_mv_test.py +0 -1
  69. brainstate/nn/_event/_linear_mv.py +0 -2
  70. brainstate/nn/_event/_linear_mv_test.py +0 -1
  71. brainstate/nn/_exp_euler.py +0 -2
  72. brainstate/nn/_exp_euler_test.py +5 -6
  73. brainstate/nn/_interaction/_conv.py +0 -2
  74. brainstate/nn/_interaction/_conv_test.py +31 -33
  75. brainstate/nn/_interaction/_embedding.py +0 -1
  76. brainstate/nn/_interaction/_linear.py +0 -2
  77. brainstate/nn/_interaction/_linear_test.py +15 -17
  78. brainstate/nn/_interaction/_normalizations.py +0 -2
  79. brainstate/nn/_interaction/_normalizations_test.py +10 -12
  80. brainstate/nn/_interaction/_poolings.py +0 -2
  81. brainstate/nn/_interaction/_poolings_test.py +19 -21
  82. brainstate/nn/_module.py +0 -1
  83. brainstate/nn/_module_test.py +34 -37
  84. brainstate/nn/metrics.py +0 -2
  85. brainstate/optim/_base.py +0 -2
  86. brainstate/optim/_lr_scheduler.py +0 -1
  87. brainstate/optim/_lr_scheduler_test.py +3 -3
  88. brainstate/optim/_optax_optimizer.py +0 -2
  89. brainstate/optim/_optax_optimizer_test.py +8 -9
  90. brainstate/optim/_sgd_optimizer.py +0 -1
  91. brainstate/random/_rand_funs.py +0 -1
  92. brainstate/random/_rand_funs_test.py +183 -184
  93. brainstate/random/_rand_seed.py +0 -1
  94. brainstate/random/_rand_seed_test.py +10 -12
  95. brainstate/random/_rand_state.py +0 -1
  96. brainstate/surrogate.py +0 -1
  97. brainstate/typing.py +0 -2
  98. brainstate/util/_caller.py +4 -6
  99. brainstate/util/_others.py +0 -2
  100. brainstate/util/_pretty_pytree.py +201 -150
  101. brainstate/util/_pretty_repr.py +0 -2
  102. brainstate/util/_pretty_table.py +57 -3
  103. brainstate/util/_scaling.py +0 -2
  104. brainstate/util/_struct.py +0 -2
  105. brainstate/util/filter.py +0 -2
  106. {brainstate-0.1.0.post20250503.dist-info → brainstate-0.1.2.dist-info}/METADATA +11 -5
  107. brainstate-0.1.2.dist-info/RECORD +133 -0
  108. brainstate-0.1.0.post20250503.dist-info/RECORD +0 -133
  109. {brainstate-0.1.0.post20250503.dist-info → brainstate-0.1.2.dist-info}/LICENSE +0 -0
  110. {brainstate-0.1.0.post20250503.dist-info → brainstate-0.1.2.dist-info}/WHEEL +0 -0
  111. {brainstate-0.1.0.post20250503.dist-info → brainstate-0.1.2.dist-info}/top_level.txt +0 -0
brainstate/__init__.py CHANGED
@@ -17,7 +17,7 @@
17
17
  A ``State``-based Transformation System for Program Compilation and Augmentation
18
18
  """
19
19
 
20
- __version__ = "0.1.0"
20
+ __version__ = "0.1.2"
21
21
 
22
22
  from . import augment
23
23
  from . import compile
@@ -35,8 +35,9 @@ __all__ = [
35
35
  'safe_map',
36
36
  'safe_zip',
37
37
  'unzip2',
38
- 'unzip3',
39
38
  'wraps',
39
+ 'Device',
40
+ 'wrap_init',
40
41
  ]
41
42
 
42
43
  T = TypeVar("T")
@@ -44,10 +45,17 @@ T1 = TypeVar("T1")
44
45
  T2 = TypeVar("T2")
45
46
  T3 = TypeVar("T3")
46
47
 
48
+
49
+ from saiunit._compatible_import import wrap_init
47
50
  brainevent_installed = importlib.util.find_spec('brainevent') is not None
48
51
 
49
52
  from jax.core import get_aval, Tracer
50
53
 
54
+ if jax.__version_info__ < (0, 5, 0):
55
+ from jax.lib.xla_client import Device
56
+ else:
57
+ from jax import Device
58
+
51
59
  if jax.__version_info__ < (0, 4, 38):
52
60
  from jax.core import ClosedJaxpr, extend_axis_env_nd, Primitive, jaxpr_as_fun
53
61
  else:
@@ -84,8 +92,7 @@ else:
84
92
  return list(zip(*args))
85
93
 
86
94
 
87
- def unzip2(xys: Iterable[tuple[T1, T2]]
88
- ) -> tuple[tuple[T1, ...], tuple[T2, ...]]:
95
+ def unzip2(xys: Iterable[tuple[T1, T2]]) -> tuple[tuple[T1, ...], tuple[T2, ...]]:
89
96
  """Unzip sequence of length-2 tuples into two tuples."""
90
97
  # Note: we deliberately don't use zip(*xys) because it is lazily evaluated,
91
98
  # is too permissive about inputs, and does not guarantee a length-2 output.
brainstate/_state.py CHANGED
@@ -133,183 +133,6 @@ def check_state_value_tree(val: bool = True) -> Generator[None, None, None]:
133
133
  TRACE_CONTEXT.tree_check.pop()
134
134
 
135
135
 
136
- class StateCatcher(PrettyObject):
137
- """
138
- The catcher to catch and manage new states.
139
-
140
- This class provides functionality to collect and tag new State objects.
141
- It ensures that each state is only added once and assigns a tag to each state.
142
-
143
- Attributes:
144
- state_tag (str): A string identifier used to tag the caught states.
145
- state_ids (set): A set of state IDs to ensure uniqueness.
146
- states (list): A list to store the caught State objects.
147
- """
148
-
149
- def __init__(
150
- self,
151
- state_tag: str,
152
- state_to_exclude: Filter = Nothing()
153
- ):
154
- """
155
- Initialize a new Catcher instance.
156
-
157
- Args:
158
- state_tag (str): The tag to be assigned to caught states.
159
- state_to_exclude (Filter, optional): A filter to exclude states from being caught.
160
- """
161
- if state_to_exclude is None:
162
- state_to_exclude = Nothing()
163
- self.state_to_exclude = state_to_exclude
164
- self.state_tag = state_tag
165
- self.state_ids = set()
166
- self.states = []
167
-
168
- def get_state_values(self) -> List[PyTree]:
169
- """
170
- Get the values of the caught states.
171
-
172
- Returns:
173
- list: A list of values of the caught states.
174
- """
175
- return [state.value for state in self.states]
176
-
177
- def get_states(self) -> List[State]:
178
- """
179
- Get the caught states.
180
-
181
- Returns:
182
- list: A list of the caught states.
183
- """
184
- return self.states
185
-
186
- def append(self, state: State):
187
- """
188
- Add a new state to the catcher if it hasn't been added before.
189
-
190
- This method adds the state to the internal list, records its ID,
191
- and assigns the catcher's tag to the state.
192
-
193
- Args:
194
- state (State): The State object to be added.
195
- """
196
- if self.state_to_exclude((), state):
197
- return
198
- if id(state) not in self.state_ids:
199
- self.state_ids.add(id(state))
200
- self.states.append(state)
201
- state.tag = self.state_tag
202
-
203
- def __iter__(self):
204
- """
205
- Allow iteration over the caught states.
206
-
207
- Returns:
208
- iterator: An iterator over the list of caught states.
209
- """
210
- return iter(self.states)
211
-
212
- def __len__(self):
213
- """
214
- Return the number of caught states.
215
-
216
- Returns:
217
- int: The number of caught states.
218
- """
219
- return len(self.states)
220
-
221
- def __getitem__(self, index):
222
- """
223
- Get a state by index.
224
-
225
- Args:
226
- index (int): The index of the state to retrieve.
227
-
228
- Returns:
229
- State: The state at the specified index.
230
- """
231
- return self.states[index]
232
-
233
- def clear(self):
234
- """
235
- Clear all caught states.
236
- """
237
- self.state_ids.clear()
238
- self.states.clear()
239
-
240
- def get_by_tag(self, tag: str):
241
- """
242
- Get all states with a specific tag.
243
-
244
- Args:
245
- tag (str): The tag to filter by.
246
-
247
- Returns:
248
- list: A list of states with the specified tag.
249
- """
250
- return [state for state in self.states if state.tag == tag]
251
-
252
- def remove(self, state: State):
253
- """
254
- Remove a specific state from the catcher.
255
-
256
- Args:
257
- state (State): The state to remove.
258
- """
259
- if id(state) in self.state_ids:
260
- self.state_ids.remove(id(state))
261
- self.states.remove(state)
262
-
263
- def __contains__(self, state: State):
264
- """
265
- Check if a state is in the catcher.
266
-
267
- Args:
268
- state (State): The state to check for.
269
-
270
- Returns:
271
- bool: True if the state is in the catcher, False otherwise.
272
- """
273
- return id(state) in self.state_ids
274
-
275
-
276
- @contextlib.contextmanager
277
- def catch_new_states(
278
- state_tag: str = None,
279
- state_to_exclude: Filter = Nothing()
280
- ) -> Generator[StateCatcher, None, None]:
281
- """
282
- A context manager that catches and tracks new states created within its scope.
283
-
284
- This function creates a new Catcher object and adds it to the TRACE_CONTEXT's
285
- new_state_catcher list. It allows for tracking and managing new states created
286
- within the context.
287
-
288
- Args:
289
- state_tag (str, optional): A string tag to associate with the caught states.
290
- Defaults to None.
291
- state_to_exclude (Filter, optional): A filter object to specify which states
292
- should be excluded from catching. Defaults to Nothing(), which excludes no states.
293
-
294
- Yields:
295
- Catcher: A Catcher object that can be used to access and manage the
296
- newly created states within the context.
297
-
298
- Example::
299
-
300
- with catch_new_states("my_tag") as catcher:
301
- # Create new states here
302
- # They will be caught and tagged with "my_tag"
303
- # Access caught states through catcher object
304
- """
305
- try:
306
- catcher = StateCatcher(state_tag=state_tag, state_to_exclude=state_to_exclude)
307
- TRACE_CONTEXT.new_state_catcher.append(catcher)
308
- yield catcher
309
- finally:
310
- TRACE_CONTEXT.new_state_catcher.pop()
311
-
312
-
313
136
  def maybe_state(val: Any) -> Any:
314
137
  """
315
138
  Extracts the value from a State object if given, otherwise returns the input value.
@@ -1226,7 +1049,7 @@ class StateTraceStack(Generic[A]):
1226
1049
  """
1227
1050
  if self._jax_trace_new_arg is not None:
1228
1051
  # internal use
1229
- state._value = jax.tree.map(lambda x: self._jax_trace_new_arg(shaped_abstractify(x)), state._value)
1052
+ state._value = jax.tree.map(self._jax_trace_new_arg, state._value)
1230
1053
 
1231
1054
  def __enter__(self) -> 'StateTraceStack':
1232
1055
  TRACE_CONTEXT.state_stack.append(self)
@@ -1660,3 +1483,180 @@ jax.tree_util.register_pytree_with_keys(
1660
1483
  _state_ref_unflatten, # type: ignore
1661
1484
  flatten_func=partial(_state_ref_flatten, with_keys=False), # type: ignore
1662
1485
  )
1486
+
1487
+
1488
+ class StateCatcher(PrettyObject):
1489
+ """
1490
+ The catcher to catch and manage new states.
1491
+
1492
+ This class provides functionality to collect and tag new State objects.
1493
+ It ensures that each state is only added once and assigns a tag to each state.
1494
+
1495
+ Attributes:
1496
+ state_tag (str): A string identifier used to tag the caught states.
1497
+ state_ids (set): A set of state IDs to ensure uniqueness.
1498
+ states (list): A list to store the caught State objects.
1499
+ """
1500
+
1501
+ def __init__(
1502
+ self,
1503
+ state_tag: str,
1504
+ state_to_exclude: Filter = Nothing()
1505
+ ):
1506
+ """
1507
+ Initialize a new Catcher instance.
1508
+
1509
+ Args:
1510
+ state_tag (str): The tag to be assigned to caught states.
1511
+ state_to_exclude (Filter, optional): A filter to exclude states from being caught.
1512
+ """
1513
+ if state_to_exclude is None:
1514
+ state_to_exclude = Nothing()
1515
+ self.state_to_exclude = state_to_exclude
1516
+ self.state_tag = state_tag
1517
+ self.state_ids = set()
1518
+ self.states = []
1519
+
1520
+ def get_state_values(self) -> List[PyTree]:
1521
+ """
1522
+ Get the values of the caught states.
1523
+
1524
+ Returns:
1525
+ list: A list of values of the caught states.
1526
+ """
1527
+ return [state.value for state in self.states]
1528
+
1529
+ def get_states(self) -> List[State]:
1530
+ """
1531
+ Get the caught states.
1532
+
1533
+ Returns:
1534
+ list: A list of the caught states.
1535
+ """
1536
+ return self.states
1537
+
1538
+ def append(self, state: State):
1539
+ """
1540
+ Add a new state to the catcher if it hasn't been added before.
1541
+
1542
+ This method adds the state to the internal list, records its ID,
1543
+ and assigns the catcher's tag to the state.
1544
+
1545
+ Args:
1546
+ state (State): The State object to be added.
1547
+ """
1548
+ if self.state_to_exclude((), state):
1549
+ return
1550
+ if id(state) not in self.state_ids:
1551
+ self.state_ids.add(id(state))
1552
+ self.states.append(state)
1553
+ state.tag = self.state_tag
1554
+
1555
+ def __iter__(self):
1556
+ """
1557
+ Allow iteration over the caught states.
1558
+
1559
+ Returns:
1560
+ iterator: An iterator over the list of caught states.
1561
+ """
1562
+ return iter(self.states)
1563
+
1564
+ def __len__(self):
1565
+ """
1566
+ Return the number of caught states.
1567
+
1568
+ Returns:
1569
+ int: The number of caught states.
1570
+ """
1571
+ return len(self.states)
1572
+
1573
+ def __getitem__(self, index):
1574
+ """
1575
+ Get a state by index.
1576
+
1577
+ Args:
1578
+ index (int): The index of the state to retrieve.
1579
+
1580
+ Returns:
1581
+ State: The state at the specified index.
1582
+ """
1583
+ return self.states[index]
1584
+
1585
+ def clear(self):
1586
+ """
1587
+ Clear all caught states.
1588
+ """
1589
+ self.state_ids.clear()
1590
+ self.states.clear()
1591
+
1592
+ def get_by_tag(self, tag: str):
1593
+ """
1594
+ Get all states with a specific tag.
1595
+
1596
+ Args:
1597
+ tag (str): The tag to filter by.
1598
+
1599
+ Returns:
1600
+ list: A list of states with the specified tag.
1601
+ """
1602
+ return [state for state in self.states if state.tag == tag]
1603
+
1604
+ def remove(self, state: State):
1605
+ """
1606
+ Remove a specific state from the catcher.
1607
+
1608
+ Args:
1609
+ state (State): The state to remove.
1610
+ """
1611
+ if id(state) in self.state_ids:
1612
+ self.state_ids.remove(id(state))
1613
+ self.states.remove(state)
1614
+
1615
+ def __contains__(self, state: State):
1616
+ """
1617
+ Check if a state is in the catcher.
1618
+
1619
+ Args:
1620
+ state (State): The state to check for.
1621
+
1622
+ Returns:
1623
+ bool: True if the state is in the catcher, False otherwise.
1624
+ """
1625
+ return id(state) in self.state_ids
1626
+
1627
+
1628
+ @contextlib.contextmanager
1629
+ def catch_new_states(
1630
+ state_tag: str = None,
1631
+ state_to_exclude: Filter = Nothing()
1632
+ ) -> Generator[StateCatcher, None, None]:
1633
+ """
1634
+ A context manager that catches and tracks new states created within its scope.
1635
+
1636
+ This function creates a new Catcher object and adds it to the TRACE_CONTEXT's
1637
+ new_state_catcher list. It allows for tracking and managing new states created
1638
+ within the context.
1639
+
1640
+ Args:
1641
+ state_tag (str, optional): A string tag to associate with the caught states.
1642
+ Defaults to None.
1643
+ state_to_exclude (Filter, optional): A filter object to specify which states
1644
+ should be excluded from catching. Defaults to Nothing(), which excludes no states.
1645
+
1646
+ Yields:
1647
+ Catcher: A Catcher object that can be used to access and manage the
1648
+ newly created states within the context.
1649
+
1650
+ Example::
1651
+
1652
+ with catch_new_states("my_tag") as catcher:
1653
+ # Create new states here
1654
+ # They will be caught and tagged with "my_tag"
1655
+ # Access caught states through catcher object
1656
+ """
1657
+ try:
1658
+ catcher = StateCatcher(state_tag=state_tag, state_to_exclude=state_to_exclude)
1659
+ TRACE_CONTEXT.new_state_catcher.append(catcher)
1660
+ yield catcher
1661
+ finally:
1662
+ TRACE_CONTEXT.new_state_catcher.pop()
brainstate/_utils.py CHANGED
@@ -13,7 +13,6 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- from __future__ import annotations
17
16
 
18
17
  import warnings
19
18
 
@@ -27,8 +27,6 @@ The wrapped gradient transformations here are made possible by using the followi
27
27
 
28
28
  """
29
29
 
30
- from __future__ import annotations
31
-
32
30
  from functools import wraps, partial
33
31
  from typing import Union, Callable, Dict, Sequence, Optional, Any, Tuple, TypeVar, Iterator
34
32