brainstate 0.1.9__py2.py3-none-any.whl → 0.2.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (163) hide show
  1. brainstate/__init__.py +130 -19
  2. brainstate/_compatible_import.py +201 -9
  3. brainstate/_compatible_import_test.py +681 -0
  4. brainstate/_deprecation.py +210 -0
  5. brainstate/_deprecation_test.py +2319 -0
  6. brainstate/{util/error.py → _error.py} +10 -20
  7. brainstate/_state.py +94 -47
  8. brainstate/_state_test.py +1 -1
  9. brainstate/_utils.py +1 -1
  10. brainstate/environ.py +1279 -347
  11. brainstate/environ_test.py +1187 -26
  12. brainstate/graph/__init__.py +6 -13
  13. brainstate/graph/_node.py +240 -0
  14. brainstate/graph/_node_test.py +589 -0
  15. brainstate/graph/{_graph_operation.py → _operation.py} +632 -746
  16. brainstate/graph/_operation_test.py +1147 -0
  17. brainstate/mixin.py +1209 -141
  18. brainstate/mixin_test.py +991 -51
  19. brainstate/nn/__init__.py +74 -72
  20. brainstate/nn/_activations.py +587 -295
  21. brainstate/nn/_activations_test.py +109 -86
  22. brainstate/nn/_collective_ops.py +393 -274
  23. brainstate/nn/_collective_ops_test.py +746 -15
  24. brainstate/nn/_common.py +114 -66
  25. brainstate/nn/_common_test.py +154 -0
  26. brainstate/nn/_conv.py +1652 -143
  27. brainstate/nn/_conv_test.py +838 -227
  28. brainstate/nn/_delay.py +95 -29
  29. brainstate/nn/_delay_test.py +25 -20
  30. brainstate/nn/_dropout.py +359 -167
  31. brainstate/nn/_dropout_test.py +429 -52
  32. brainstate/nn/_dynamics.py +14 -90
  33. brainstate/nn/_dynamics_test.py +1 -12
  34. brainstate/nn/_elementwise.py +492 -313
  35. brainstate/nn/_elementwise_test.py +806 -145
  36. brainstate/nn/_embedding.py +369 -19
  37. brainstate/nn/_embedding_test.py +156 -0
  38. brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +10 -16
  39. brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +6 -5
  40. brainstate/nn/{_linear_mv.py → _event_linear.py} +2 -2
  41. brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +6 -5
  42. brainstate/nn/_exp_euler.py +200 -38
  43. brainstate/nn/_exp_euler_test.py +350 -8
  44. brainstate/nn/_linear.py +391 -71
  45. brainstate/nn/_linear_test.py +427 -59
  46. brainstate/nn/_metrics.py +1070 -0
  47. brainstate/nn/_metrics_test.py +611 -0
  48. brainstate/nn/_module.py +10 -3
  49. brainstate/nn/_module_test.py +1 -1
  50. brainstate/nn/_normalizations.py +688 -329
  51. brainstate/nn/_normalizations_test.py +663 -37
  52. brainstate/nn/_paddings.py +1020 -0
  53. brainstate/nn/_paddings_test.py +723 -0
  54. brainstate/nn/_poolings.py +1404 -342
  55. brainstate/nn/_poolings_test.py +828 -92
  56. brainstate/nn/{_rate_rnns.py → _rnns.py} +446 -54
  57. brainstate/nn/_rnns_test.py +593 -0
  58. brainstate/nn/_utils.py +132 -5
  59. brainstate/nn/_utils_test.py +402 -0
  60. brainstate/{init/_random_inits.py → nn/init.py} +301 -45
  61. brainstate/{init/_random_inits_test.py → nn/init_test.py} +51 -20
  62. brainstate/random/__init__.py +247 -1
  63. brainstate/random/_rand_funs.py +668 -346
  64. brainstate/random/_rand_funs_test.py +74 -1
  65. brainstate/random/_rand_seed.py +541 -76
  66. brainstate/random/_rand_seed_test.py +1 -1
  67. brainstate/random/_rand_state.py +601 -393
  68. brainstate/random/_rand_state_test.py +551 -0
  69. brainstate/transform/__init__.py +59 -0
  70. brainstate/transform/_ad_checkpoint.py +176 -0
  71. brainstate/{compile → transform}/_ad_checkpoint_test.py +1 -1
  72. brainstate/{augment → transform}/_autograd.py +360 -113
  73. brainstate/{augment → transform}/_autograd_test.py +2 -2
  74. brainstate/transform/_conditions.py +316 -0
  75. brainstate/{compile → transform}/_conditions_test.py +11 -11
  76. brainstate/{compile → transform}/_error_if.py +22 -20
  77. brainstate/{compile → transform}/_error_if_test.py +1 -1
  78. brainstate/transform/_eval_shape.py +145 -0
  79. brainstate/{augment → transform}/_eval_shape_test.py +1 -1
  80. brainstate/{compile → transform}/_jit.py +99 -46
  81. brainstate/{compile → transform}/_jit_test.py +3 -3
  82. brainstate/{compile → transform}/_loop_collect_return.py +219 -80
  83. brainstate/{compile → transform}/_loop_collect_return_test.py +1 -1
  84. brainstate/{compile → transform}/_loop_no_collection.py +133 -34
  85. brainstate/{compile → transform}/_loop_no_collection_test.py +2 -2
  86. brainstate/transform/_make_jaxpr.py +2016 -0
  87. brainstate/transform/_make_jaxpr_test.py +1510 -0
  88. brainstate/transform/_mapping.py +529 -0
  89. brainstate/transform/_mapping_test.py +194 -0
  90. brainstate/{compile → transform}/_progress_bar.py +78 -25
  91. brainstate/{augment → transform}/_random.py +65 -45
  92. brainstate/{compile → transform}/_unvmap.py +102 -5
  93. brainstate/transform/_util.py +286 -0
  94. brainstate/typing.py +594 -61
  95. brainstate/typing_test.py +780 -0
  96. brainstate/util/__init__.py +9 -32
  97. brainstate/util/_others.py +1025 -0
  98. brainstate/util/_others_test.py +962 -0
  99. brainstate/util/_pretty_pytree.py +1301 -0
  100. brainstate/util/_pretty_pytree_test.py +675 -0
  101. brainstate/util/{pretty_repr.py → _pretty_repr.py} +161 -27
  102. brainstate/util/_pretty_repr_test.py +696 -0
  103. brainstate/util/filter.py +557 -81
  104. brainstate/util/filter_test.py +912 -0
  105. brainstate/util/struct.py +769 -382
  106. brainstate/util/struct_test.py +602 -0
  107. {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/METADATA +34 -17
  108. brainstate-0.2.0.dist-info/RECORD +111 -0
  109. brainstate/augment/__init__.py +0 -30
  110. brainstate/augment/_eval_shape.py +0 -99
  111. brainstate/augment/_mapping.py +0 -1060
  112. brainstate/augment/_mapping_test.py +0 -597
  113. brainstate/compile/__init__.py +0 -38
  114. brainstate/compile/_ad_checkpoint.py +0 -204
  115. brainstate/compile/_conditions.py +0 -256
  116. brainstate/compile/_make_jaxpr.py +0 -888
  117. brainstate/compile/_make_jaxpr_test.py +0 -156
  118. brainstate/compile/_util.py +0 -147
  119. brainstate/functional/__init__.py +0 -27
  120. brainstate/graph/_graph_node.py +0 -244
  121. brainstate/graph/_graph_node_test.py +0 -73
  122. brainstate/graph/_graph_operation_test.py +0 -563
  123. brainstate/init/__init__.py +0 -26
  124. brainstate/init/_base.py +0 -52
  125. brainstate/init/_generic.py +0 -244
  126. brainstate/init/_regular_inits.py +0 -105
  127. brainstate/init/_regular_inits_test.py +0 -50
  128. brainstate/nn/_inputs.py +0 -608
  129. brainstate/nn/_ltp.py +0 -28
  130. brainstate/nn/_neuron.py +0 -705
  131. brainstate/nn/_neuron_test.py +0 -161
  132. brainstate/nn/_others.py +0 -46
  133. brainstate/nn/_projection.py +0 -486
  134. brainstate/nn/_rate_rnns_test.py +0 -63
  135. brainstate/nn/_readout.py +0 -209
  136. brainstate/nn/_readout_test.py +0 -53
  137. brainstate/nn/_stp.py +0 -236
  138. brainstate/nn/_synapse.py +0 -505
  139. brainstate/nn/_synapse_test.py +0 -131
  140. brainstate/nn/_synaptic_projection.py +0 -423
  141. brainstate/nn/_synouts.py +0 -162
  142. brainstate/nn/_synouts_test.py +0 -57
  143. brainstate/nn/metrics.py +0 -388
  144. brainstate/optim/__init__.py +0 -38
  145. brainstate/optim/_base.py +0 -64
  146. brainstate/optim/_lr_scheduler.py +0 -448
  147. brainstate/optim/_lr_scheduler_test.py +0 -50
  148. brainstate/optim/_optax_optimizer.py +0 -152
  149. brainstate/optim/_optax_optimizer_test.py +0 -53
  150. brainstate/optim/_sgd_optimizer.py +0 -1104
  151. brainstate/random/_random_for_unit.py +0 -52
  152. brainstate/surrogate.py +0 -1957
  153. brainstate/transform.py +0 -23
  154. brainstate/util/caller.py +0 -98
  155. brainstate/util/others.py +0 -540
  156. brainstate/util/pretty_pytree.py +0 -945
  157. brainstate/util/pretty_pytree_test.py +0 -159
  158. brainstate/util/pretty_table.py +0 -2954
  159. brainstate/util/scaling.py +0 -258
  160. brainstate-0.1.9.dist-info/RECORD +0 -130
  161. {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/WHEEL +0 -0
  162. {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/licenses/LICENSE +0 -0
  163. {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/top_level.txt +0 -0
brainstate/nn/_delay.py CHANGED
@@ -1,4 +1,4 @@
1
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
1
+ # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -24,8 +24,8 @@ import numpy as np
24
24
 
25
25
  from brainstate import environ
26
26
  from brainstate._state import ShortTermState, State
27
- from brainstate.compile import jit_error_if
28
27
  from brainstate.graph import Node
28
+ from brainstate.transform import jit_error_if
29
29
  from brainstate.typing import ArrayLike, PyTree
30
30
  from ._collective_ops import call_order
31
31
  from ._module import Module
@@ -59,8 +59,8 @@ class DelayAccess(Node):
59
59
 
60
60
  Args:
61
61
  delay: The delay instance.
62
- time: The delay time.
63
- delay_entry: The delay entry.
62
+ *time: The delay time.
63
+ entry: The delay entry.
64
64
  """
65
65
 
66
66
  __module__ = 'brainstate.nn'
@@ -68,22 +68,22 @@ class DelayAccess(Node):
68
68
  def __init__(
69
69
  self,
70
70
  delay: 'Delay',
71
- time: Union[None, int, float],
72
- delay_entry: str,
71
+ *time,
72
+ entry: str,
73
73
  ):
74
74
  super().__init__()
75
- self.refs = {'delay': delay}
75
+ self.delay = delay
76
76
  assert isinstance(delay, Delay), 'The input delay should be an instance of Delay.'
77
- self._delay_entry = delay_entry
78
- self.delay_info = delay.register_entry(self._delay_entry, time)
77
+ self._delay_entry = entry
78
+ self.delay_info = delay.register_entry(self._delay_entry, *time)
79
79
 
80
80
  def update(self):
81
- return self.refs['delay'].at(self._delay_entry)
81
+ return self.delay.at(self._delay_entry)
82
82
 
83
83
 
84
84
  class Delay(Module):
85
85
  """
86
- Generate Delays for the given :py:class:`~.State` instance.
86
+ Delay variable for storing short-term history data.
87
87
 
88
88
  The data in this delay variable is arranged as::
89
89
 
@@ -240,21 +240,8 @@ class Delay(Module):
240
240
  >>> delay_obj.register_delay(jnp.array([2.0, 3.0]), 0, 1) # Vector delay with indices
241
241
  """
242
242
  assert len(delay_time) >= 1, 'You should provide at least one delay time.'
243
- delay_size = u.math.size(delay_time[0])
244
243
  for dt in delay_time[1:]:
245
244
  assert jnp.issubdtype(u.math.get_dtype(dt), jnp.integer), f'The index should be integer. But got {dt}.'
246
- # delay_size = u.math.size(delay_time[0])
247
- # for dt in delay_time:
248
- # if u.math.ndim(dt) == 0:
249
- # pass
250
- # elif u.math.ndim(dt) == 1:
251
- # if u.math.size(dt) != delay_size:
252
- # raise ValueError(
253
- # f'The delay time should be a scalar or a vector with the same size. '
254
- # f'But got {delay_time}. The delay time {dt} has size {u.math.size(dt)}'
255
- # )
256
- # else:
257
- # raise ValueError(f'The delay time should be a scalar/vector. But got {dt}.')
258
245
  if delay_time[0] is None:
259
246
  return None
260
247
  with jax.ensure_compile_time_eval():
@@ -287,7 +274,7 @@ class Delay(Module):
287
274
  self._registered_entries[entry] = delay_info
288
275
  return delay_info
289
276
 
290
def access(self, entry: str, *delay_time) -> DelayAccess:
    """
    Create a DelayAccess object for a specific delay entry and delay time.

    Args:
        entry: The name of the delay entry to register and later read.
        *delay_time: The delay time, optionally followed by extra positional
            values (e.g. integer indices). They are forwarded unchanged,
            through ``DelayAccess``, to ``register_entry``.

    Returns:
        DelayAccess: An object that provides access to the delay data for the specified entry and time.
    """
    # ``DelayAccess.__init__`` takes ``*time`` and calls
    # ``register_entry(entry, *time)``. Forward the same positional
    # arguments here. Passing the ``delay_time`` tuple itself would make
    # ``register_entry`` receive one tuple as the delay time instead of
    # ``(time, *indices)``.
    return DelayAccess(self, *delay_time, entry=entry)
302
289
 
303
290
  def at(self, entry: str) -> ArrayLike:
304
291
  """
@@ -472,15 +459,94 @@ class Delay(Module):
472
459
 
473
460
  class StateWithDelay(Delay):
474
461
  """
475
- A ``State`` type that defines the state in a differential equation.
462
+ Delayed history buffer bound to a module state.
463
+
464
+ StateWithDelay is a specialized :py:class:`~.Delay` that attaches to a
465
+ concrete :py:class:`~brainstate._state.State` living on a target module
466
+ (for example a membrane potential ``V`` on a neuron). It automatically
467
+ maintains a rolling history of that state and exposes convenient helpers to
468
+ retrieve the value at a given delay either by step or by time.
469
+
470
+ In normal usage you rarely instantiate this class directly. It is created
471
+ implicitly when using the prefetch-delay helpers on a Dynamics module, e.g.:
472
+
473
+ - ``module.prefetch('V').delay.at(5.0 * u.ms)``
474
+ - ``module.prefetch_delay('V', 5.0 * u.ms)``
475
+
476
+ Both will construct a StateWithDelay bound to ``module.V`` under the hood
477
+ and register the requested delay, so you can retrieve the delayed value
478
+ inside your update rules.
479
+
480
+ Parameters
481
+ ----------
482
+ target : :py:class:`~brainstate.graph.Node`
483
+ The module object that owns the state to track.
484
+ item : str
485
+ The attribute name of the target state on ``target`` (must be a
486
+ :py:class:`~brainstate._state.State`).
487
+ init : Callable, optional
488
+ Optional initializer used to fill the history buffer before ``t0``
489
+ when delays request values from the past that hasn't been simulated yet.
490
+ The callable receives ``(shape, dtype)`` and must return an array.
491
+ If not provided, zeros are used. You may also pass a scalar/array
492
+ literal via the underlying Delay API when constructing manually.
493
+ delay_method : {"rotation", "concat"}, default "rotation"
494
+ Internal buffering strategy (inherits behavior from :py:class:`~.Delay`).
495
+ "rotation" keeps a ring buffer; "concat" shifts by concatenation.
496
+
497
+ Attributes
498
+ ----------
499
+ state : :py:class:`~brainstate._state.State`
500
+ The concrete state object being tracked.
501
+ history : :py:class:`~brainstate._state.ShortTermState`
502
+ Rolling time axis buffer with shape ``[length, *state.shape]``.
503
+ max_time : float
504
+ Maximum time span currently supported by the buffer.
505
+ max_length : int
506
+ Buffer length in steps (``ceil(max_time/dt)+1``).
507
+
508
+ Notes
509
+ -----
510
+ - This class inherits all retrieval utilities from :py:class:`~.Delay`:
511
+ use :py:meth:`retrieve_at_step` when you know the integer delay steps,
512
+ or :py:meth:`retrieve_at_time` for continuous-time queries with optional
513
+ linear/round interpolation.
514
+ - It is registered as an "after-update" hook on the owning Dynamics so the
515
+ buffer is updated automatically after each simulation step.
516
+
517
+ Examples
518
+ --------
519
+ Access a neuron's membrane potential 5 ms in the past:
520
+
521
+ >>> import brainunit as u
522
+ >>> import brainstate as brainstate
523
+ >>> lif = brainstate.nn.LIF(100)
524
+ >>> # Create a delayed accessor to V(t-5ms)
525
+ >>> v_delay = lif.prefetch_delay('V', 5.0 * u.ms)
526
+ >>> # Inside another module's update you can read the delayed value
527
+ >>> v_t_minus_5ms = v_delay()
528
+
529
+ Register multiple delay taps and index-specific delays:
530
+
531
+ >>> # Under the hood, a StateWithDelay is created and you can register
532
+ >>> # additional taps (in steps or time) via its Delay interface
533
+ >>> _ = lif.prefetch('V').delay.at(2.0 * u.ms) # additional delay
534
+ >>> # Direct access to buffer by steps (advanced)
535
+ >>> # lif._get_after_update('V-prefetch-delay').retrieve_at_step(3)
476
536
  """
477
537
 
478
538
  __module__ = 'brainstate.nn'
479
539
 
480
540
  state: State # state
481
541
 
482
- def __init__(self, target: Node, item: str, init: Callable = None):
483
- super().__init__(None, init=init)
542
+ def __init__(
543
+ self,
544
+ target: Node,
545
+ item: str,
546
+ init: Callable = None,
547
+ delay_method: Optional[str] = _DELAY_ROTATE,
548
+ ):
549
+ super().__init__(None, init=init, delay_method=delay_method)
484
550
 
485
551
  self._target = target
486
552
  self._target_term = item
@@ -1,4 +1,4 @@
1
- # Copyright 2025 BDP Ecosystem Limited. All Rights Reserved.
1
+ # Copyright 2025 BrainX Ecosystem Limited. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -20,10 +20,14 @@ import jax.numpy as jnp
20
20
 
21
21
  import brainstate
22
22
 
23
- brainstate.environ.set(dt=0.1)
24
-
25
23
 
26
24
  class TestDelay(unittest.TestCase):
25
+ def setUp(self):
26
+ brainstate.environ.set(dt=0.1)
27
+
28
+ def tearDown(self):
29
+ brainstate.environ.pop('dt')
30
+
27
31
  def test_delay1(self):
28
32
  a = brainstate.State(brainstate.random.random(10, 20))
29
33
  delay = brainstate.nn.Delay(a.value)
@@ -61,26 +65,27 @@ class TestDelay(unittest.TestCase):
61
65
  self.assertTrue(jnp.allclose(rotation_delay.at('c'), jnp.maximum(jnp.ones((1,)) * i - n2, 0.)))
62
66
 
63
67
  def test_concat_delay(self):
64
- rotation_delay = brainstate.nn.Delay(jnp.ones([1]), delay_method='concat')
65
- t0 = 0.
66
- t1, n1 = 1., 10
67
- t2, n2 = 2., 20
68
+ with brainstate.environ.context(dt=0.1) as env:
69
+ rotation_delay = brainstate.nn.Delay(jnp.ones([1]), delay_method='concat')
70
+ t0 = 0.
71
+ t1, n1 = 1., 10
72
+ t2, n2 = 2., 20
68
73
 
69
- rotation_delay.register_entry('a', t0)
70
- rotation_delay.register_entry('b', t1)
71
- rotation_delay.register_entry('c', t2)
74
+ rotation_delay.register_entry('a', t0)
75
+ rotation_delay.register_entry('b', t1)
76
+ rotation_delay.register_entry('c', t2)
72
77
 
73
- rotation_delay.init_state()
78
+ rotation_delay.init_state()
74
79
 
75
- print()
76
- for i in range(100):
77
- brainstate.environ.set(i=i)
78
- rotation_delay.update(jnp.ones((1,)) * i)
79
- print(i, rotation_delay.at('a'), rotation_delay.at('b'), rotation_delay.at('c'))
80
- self.assertTrue(jnp.allclose(rotation_delay.at('a'), jnp.ones((1,)) * i))
81
- self.assertTrue(jnp.allclose(rotation_delay.at('b'), jnp.maximum(jnp.ones((1,)) * i - n1, 0.)))
82
- self.assertTrue(jnp.allclose(rotation_delay.at('c'), jnp.maximum(jnp.ones((1,)) * i - n2, 0.)))
83
- # brainstate.util.clear_buffer_memory()
80
+ print()
81
+ for i in range(100):
82
+ brainstate.environ.set(i=i)
83
+ rotation_delay.update(jnp.ones((1,)) * i)
84
+ print(i, rotation_delay.at('a'), rotation_delay.at('b'), rotation_delay.at('c'))
85
+ self.assertTrue(jnp.allclose(rotation_delay.at('a'), jnp.ones((1,)) * i))
86
+ self.assertTrue(jnp.allclose(rotation_delay.at('b'), jnp.maximum(jnp.ones((1,)) * i - n1, 0.)))
87
+ self.assertTrue(jnp.allclose(rotation_delay.at('c'), jnp.maximum(jnp.ones((1,)) * i - n2, 0.)))
88
+ # brainstate.util.clear_buffer_memory()
84
89
 
85
90
  def test_jit_erro(self):
86
91
  rotation_delay = brainstate.nn.Delay(jnp.ones([1]), time=2., delay_method='concat', interp_method='round')