brainstate 0.1.10__py2.py3-none-any.whl → 0.2.0__py2.py3-none-any.whl

This diff shows the differences between the contents of two publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
Files changed (163)
  1. brainstate/__init__.py +130 -19
  2. brainstate/_compatible_import.py +201 -9
  3. brainstate/_compatible_import_test.py +681 -0
  4. brainstate/_deprecation.py +210 -0
  5. brainstate/_deprecation_test.py +2319 -0
  6. brainstate/{util/error.py → _error.py} +10 -20
  7. brainstate/_state.py +94 -47
  8. brainstate/_state_test.py +1 -1
  9. brainstate/_utils.py +1 -1
  10. brainstate/environ.py +1279 -347
  11. brainstate/environ_test.py +1187 -26
  12. brainstate/graph/__init__.py +6 -13
  13. brainstate/graph/_node.py +240 -0
  14. brainstate/graph/_node_test.py +589 -0
  15. brainstate/graph/{_graph_operation.py → _operation.py} +632 -746
  16. brainstate/graph/_operation_test.py +1147 -0
  17. brainstate/mixin.py +1209 -141
  18. brainstate/mixin_test.py +991 -51
  19. brainstate/nn/__init__.py +74 -72
  20. brainstate/nn/_activations.py +587 -295
  21. brainstate/nn/_activations_test.py +109 -86
  22. brainstate/nn/_collective_ops.py +393 -274
  23. brainstate/nn/_collective_ops_test.py +746 -15
  24. brainstate/nn/_common.py +114 -66
  25. brainstate/nn/_common_test.py +154 -0
  26. brainstate/nn/_conv.py +1652 -143
  27. brainstate/nn/_conv_test.py +838 -227
  28. brainstate/nn/_delay.py +15 -28
  29. brainstate/nn/_delay_test.py +25 -20
  30. brainstate/nn/_dropout.py +359 -167
  31. brainstate/nn/_dropout_test.py +429 -52
  32. brainstate/nn/_dynamics.py +14 -90
  33. brainstate/nn/_dynamics_test.py +1 -12
  34. brainstate/nn/_elementwise.py +492 -313
  35. brainstate/nn/_elementwise_test.py +806 -145
  36. brainstate/nn/_embedding.py +369 -19
  37. brainstate/nn/_embedding_test.py +156 -0
  38. brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +10 -16
  39. brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +6 -5
  40. brainstate/nn/{_linear_mv.py → _event_linear.py} +2 -2
  41. brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +6 -5
  42. brainstate/nn/_exp_euler.py +200 -38
  43. brainstate/nn/_exp_euler_test.py +350 -8
  44. brainstate/nn/_linear.py +391 -71
  45. brainstate/nn/_linear_test.py +427 -59
  46. brainstate/nn/_metrics.py +1070 -0
  47. brainstate/nn/_metrics_test.py +611 -0
  48. brainstate/nn/_module.py +10 -3
  49. brainstate/nn/_module_test.py +1 -1
  50. brainstate/nn/_normalizations.py +688 -329
  51. brainstate/nn/_normalizations_test.py +663 -37
  52. brainstate/nn/_paddings.py +1020 -0
  53. brainstate/nn/_paddings_test.py +723 -0
  54. brainstate/nn/_poolings.py +1404 -342
  55. brainstate/nn/_poolings_test.py +828 -92
  56. brainstate/nn/{_rate_rnns.py → _rnns.py} +446 -54
  57. brainstate/nn/_rnns_test.py +593 -0
  58. brainstate/nn/_utils.py +132 -5
  59. brainstate/nn/_utils_test.py +402 -0
  60. brainstate/{init/_random_inits.py → nn/init.py} +301 -45
  61. brainstate/{init/_random_inits_test.py → nn/init_test.py} +51 -20
  62. brainstate/random/__init__.py +247 -1
  63. brainstate/random/_rand_funs.py +668 -346
  64. brainstate/random/_rand_funs_test.py +74 -1
  65. brainstate/random/_rand_seed.py +541 -76
  66. brainstate/random/_rand_seed_test.py +1 -1
  67. brainstate/random/_rand_state.py +601 -393
  68. brainstate/random/_rand_state_test.py +551 -0
  69. brainstate/transform/__init__.py +59 -0
  70. brainstate/transform/_ad_checkpoint.py +176 -0
  71. brainstate/{compile → transform}/_ad_checkpoint_test.py +1 -1
  72. brainstate/{augment → transform}/_autograd.py +360 -113
  73. brainstate/{augment → transform}/_autograd_test.py +2 -2
  74. brainstate/transform/_conditions.py +316 -0
  75. brainstate/{compile → transform}/_conditions_test.py +11 -11
  76. brainstate/{compile → transform}/_error_if.py +22 -20
  77. brainstate/{compile → transform}/_error_if_test.py +1 -1
  78. brainstate/transform/_eval_shape.py +145 -0
  79. brainstate/{augment → transform}/_eval_shape_test.py +1 -1
  80. brainstate/{compile → transform}/_jit.py +99 -46
  81. brainstate/{compile → transform}/_jit_test.py +3 -3
  82. brainstate/{compile → transform}/_loop_collect_return.py +219 -80
  83. brainstate/{compile → transform}/_loop_collect_return_test.py +1 -1
  84. brainstate/{compile → transform}/_loop_no_collection.py +133 -34
  85. brainstate/{compile → transform}/_loop_no_collection_test.py +2 -2
  86. brainstate/transform/_make_jaxpr.py +2016 -0
  87. brainstate/transform/_make_jaxpr_test.py +1510 -0
  88. brainstate/transform/_mapping.py +529 -0
  89. brainstate/transform/_mapping_test.py +194 -0
  90. brainstate/{compile → transform}/_progress_bar.py +78 -25
  91. brainstate/{augment → transform}/_random.py +65 -45
  92. brainstate/{compile → transform}/_unvmap.py +102 -5
  93. brainstate/transform/_util.py +286 -0
  94. brainstate/typing.py +594 -61
  95. brainstate/typing_test.py +780 -0
  96. brainstate/util/__init__.py +9 -32
  97. brainstate/util/_others.py +1025 -0
  98. brainstate/util/_others_test.py +962 -0
  99. brainstate/util/_pretty_pytree.py +1301 -0
  100. brainstate/util/_pretty_pytree_test.py +675 -0
  101. brainstate/util/{pretty_repr.py → _pretty_repr.py} +161 -27
  102. brainstate/util/_pretty_repr_test.py +696 -0
  103. brainstate/util/filter.py +557 -81
  104. brainstate/util/filter_test.py +912 -0
  105. brainstate/util/struct.py +769 -382
  106. brainstate/util/struct_test.py +602 -0
  107. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/METADATA +34 -17
  108. brainstate-0.2.0.dist-info/RECORD +111 -0
  109. brainstate/augment/__init__.py +0 -30
  110. brainstate/augment/_eval_shape.py +0 -99
  111. brainstate/augment/_mapping.py +0 -1060
  112. brainstate/augment/_mapping_test.py +0 -597
  113. brainstate/compile/__init__.py +0 -38
  114. brainstate/compile/_ad_checkpoint.py +0 -204
  115. brainstate/compile/_conditions.py +0 -256
  116. brainstate/compile/_make_jaxpr.py +0 -888
  117. brainstate/compile/_make_jaxpr_test.py +0 -156
  118. brainstate/compile/_util.py +0 -147
  119. brainstate/functional/__init__.py +0 -27
  120. brainstate/graph/_graph_node.py +0 -244
  121. brainstate/graph/_graph_node_test.py +0 -73
  122. brainstate/graph/_graph_operation_test.py +0 -563
  123. brainstate/init/__init__.py +0 -26
  124. brainstate/init/_base.py +0 -52
  125. brainstate/init/_generic.py +0 -244
  126. brainstate/init/_regular_inits.py +0 -105
  127. brainstate/init/_regular_inits_test.py +0 -50
  128. brainstate/nn/_inputs.py +0 -608
  129. brainstate/nn/_ltp.py +0 -28
  130. brainstate/nn/_neuron.py +0 -705
  131. brainstate/nn/_neuron_test.py +0 -161
  132. brainstate/nn/_others.py +0 -46
  133. brainstate/nn/_projection.py +0 -486
  134. brainstate/nn/_rate_rnns_test.py +0 -63
  135. brainstate/nn/_readout.py +0 -209
  136. brainstate/nn/_readout_test.py +0 -53
  137. brainstate/nn/_stp.py +0 -236
  138. brainstate/nn/_synapse.py +0 -505
  139. brainstate/nn/_synapse_test.py +0 -131
  140. brainstate/nn/_synaptic_projection.py +0 -423
  141. brainstate/nn/_synouts.py +0 -162
  142. brainstate/nn/_synouts_test.py +0 -57
  143. brainstate/nn/metrics.py +0 -388
  144. brainstate/optim/__init__.py +0 -38
  145. brainstate/optim/_base.py +0 -64
  146. brainstate/optim/_lr_scheduler.py +0 -448
  147. brainstate/optim/_lr_scheduler_test.py +0 -50
  148. brainstate/optim/_optax_optimizer.py +0 -152
  149. brainstate/optim/_optax_optimizer_test.py +0 -53
  150. brainstate/optim/_sgd_optimizer.py +0 -1104
  151. brainstate/random/_random_for_unit.py +0 -52
  152. brainstate/surrogate.py +0 -1957
  153. brainstate/transform.py +0 -23
  154. brainstate/util/caller.py +0 -98
  155. brainstate/util/others.py +0 -540
  156. brainstate/util/pretty_pytree.py +0 -945
  157. brainstate/util/pretty_pytree_test.py +0 -159
  158. brainstate/util/pretty_table.py +0 -2954
  159. brainstate/util/scaling.py +0 -258
  160. brainstate-0.1.10.dist-info/RECORD +0 -130
  161. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/WHEEL +0 -0
  162. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/licenses/LICENSE +0 -0
  163. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/top_level.txt +0 -0
brainstate/nn/_delay.py CHANGED
@@ -1,4 +1,4 @@
1
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
1
+ # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -24,8 +24,8 @@ import numpy as np
24
24
 
25
25
  from brainstate import environ
26
26
  from brainstate._state import ShortTermState, State
27
- from brainstate.compile import jit_error_if
28
27
  from brainstate.graph import Node
28
+ from brainstate.transform import jit_error_if
29
29
  from brainstate.typing import ArrayLike, PyTree
30
30
  from ._collective_ops import call_order
31
31
  from ._module import Module
@@ -59,8 +59,8 @@ class DelayAccess(Node):
59
59
 
60
60
  Args:
61
61
  delay: The delay instance.
62
- time: The delay time.
63
- delay_entry: The delay entry.
62
+ *time: The delay time.
63
+ entry: The delay entry.
64
64
  """
65
65
 
66
66
  __module__ = 'brainstate.nn'
@@ -68,22 +68,22 @@ class DelayAccess(Node):
68
68
  def __init__(
69
69
  self,
70
70
  delay: 'Delay',
71
- time: Union[None, int, float],
72
- delay_entry: str,
71
+ *time,
72
+ entry: str,
73
73
  ):
74
74
  super().__init__()
75
- self.refs = {'delay': delay}
75
+ self.delay = delay
76
76
  assert isinstance(delay, Delay), 'The input delay should be an instance of Delay.'
77
- self._delay_entry = delay_entry
78
- self.delay_info = delay.register_entry(self._delay_entry, time)
77
+ self._delay_entry = entry
78
+ self.delay_info = delay.register_entry(self._delay_entry, *time)
79
79
 
80
80
  def update(self):
81
- return self.refs['delay'].at(self._delay_entry)
81
+ return self.delay.at(self._delay_entry)
82
82
 
83
83
 
84
84
  class Delay(Module):
85
85
  """
86
- Generate Delays for the given :py:class:`~.State` instance.
86
+ Delay variable for storing short-term history data.
87
87
 
88
88
  The data in this delay variable is arranged as::
89
89
 
@@ -240,21 +240,8 @@ class Delay(Module):
240
240
  >>> delay_obj.register_delay(jnp.array([2.0, 3.0]), 0, 1) # Vector delay with indices
241
241
  """
242
242
  assert len(delay_time) >= 1, 'You should provide at least one delay time.'
243
- delay_size = u.math.size(delay_time[0])
244
243
  for dt in delay_time[1:]:
245
244
  assert jnp.issubdtype(u.math.get_dtype(dt), jnp.integer), f'The index should be integer. But got {dt}.'
246
- # delay_size = u.math.size(delay_time[0])
247
- # for dt in delay_time:
248
- # if u.math.ndim(dt) == 0:
249
- # pass
250
- # elif u.math.ndim(dt) == 1:
251
- # if u.math.size(dt) != delay_size:
252
- # raise ValueError(
253
- # f'The delay time should be a scalar or a vector with the same size. '
254
- # f'But got {delay_time}. The delay time {dt} has size {u.math.size(dt)}'
255
- # )
256
- # else:
257
- # raise ValueError(f'The delay time should be a scalar/vector. But got {dt}.')
258
245
  if delay_time[0] is None:
259
246
  return None
260
247
  with jax.ensure_compile_time_eval():
@@ -287,7 +274,7 @@ class Delay(Module):
287
274
  self._registered_entries[entry] = delay_info
288
275
  return delay_info
289
276
 
290
- def access(self, entry: str, delay_time: Sequence) -> DelayAccess:
277
+ def access(self, entry: str, *delay_time) -> DelayAccess:
291
278
  """
292
279
  Create a DelayAccess object for a specific delay entry and delay time.
293
280
 
@@ -298,7 +285,7 @@ class Delay(Module):
298
285
  Returns:
299
286
  DelayAccess: An object that provides access to the delay data for the specified entry and time.
300
287
  """
301
- return DelayAccess(self, delay_time, delay_entry=entry)
288
+ return DelayAccess(self, delay_time, entry=entry)
302
289
 
303
290
  def at(self, entry: str) -> ArrayLike:
304
291
  """
@@ -532,8 +519,8 @@ class StateWithDelay(Delay):
532
519
  Access a neuron's membrane potential 5 ms in the past:
533
520
 
534
521
  >>> import brainunit as u
535
- >>> import brainstate as bst
536
- >>> lif = bst.nn.LIF(100)
522
+ >>> import brainstate as brainstate
523
+ >>> lif = brainstate.nn.LIF(100)
537
524
  >>> # Create a delayed accessor to V(t-5ms)
538
525
  >>> v_delay = lif.prefetch_delay('V', 5.0 * u.ms)
539
526
  >>> # Inside another module's update you can read the delayed value
brainstate/nn/_delay_test.py CHANGED
@@ -1,4 +1,4 @@
1
- # Copyright 2025 BDP Ecosystem Limited. All Rights Reserved.
1
+ # Copyright 2025 BrainX Ecosystem Limited. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -20,10 +20,14 @@ import jax.numpy as jnp
20
20
 
21
21
  import brainstate
22
22
 
23
- brainstate.environ.set(dt=0.1)
24
-
25
23
 
26
24
  class TestDelay(unittest.TestCase):
25
+ def setUp(self):
26
+ brainstate.environ.set(dt=0.1)
27
+
28
+ def tearDown(self):
29
+ brainstate.environ.pop('dt')
30
+
27
31
  def test_delay1(self):
28
32
  a = brainstate.State(brainstate.random.random(10, 20))
29
33
  delay = brainstate.nn.Delay(a.value)
@@ -61,26 +65,27 @@ class TestDelay(unittest.TestCase):
61
65
  self.assertTrue(jnp.allclose(rotation_delay.at('c'), jnp.maximum(jnp.ones((1,)) * i - n2, 0.)))
62
66
 
63
67
  def test_concat_delay(self):
64
- rotation_delay = brainstate.nn.Delay(jnp.ones([1]), delay_method='concat')
65
- t0 = 0.
66
- t1, n1 = 1., 10
67
- t2, n2 = 2., 20
68
+ with brainstate.environ.context(dt=0.1) as env:
69
+ rotation_delay = brainstate.nn.Delay(jnp.ones([1]), delay_method='concat')
70
+ t0 = 0.
71
+ t1, n1 = 1., 10
72
+ t2, n2 = 2., 20
68
73
 
69
- rotation_delay.register_entry('a', t0)
70
- rotation_delay.register_entry('b', t1)
71
- rotation_delay.register_entry('c', t2)
74
+ rotation_delay.register_entry('a', t0)
75
+ rotation_delay.register_entry('b', t1)
76
+ rotation_delay.register_entry('c', t2)
72
77
 
73
- rotation_delay.init_state()
78
+ rotation_delay.init_state()
74
79
 
75
- print()
76
- for i in range(100):
77
- brainstate.environ.set(i=i)
78
- rotation_delay.update(jnp.ones((1,)) * i)
79
- print(i, rotation_delay.at('a'), rotation_delay.at('b'), rotation_delay.at('c'))
80
- self.assertTrue(jnp.allclose(rotation_delay.at('a'), jnp.ones((1,)) * i))
81
- self.assertTrue(jnp.allclose(rotation_delay.at('b'), jnp.maximum(jnp.ones((1,)) * i - n1, 0.)))
82
- self.assertTrue(jnp.allclose(rotation_delay.at('c'), jnp.maximum(jnp.ones((1,)) * i - n2, 0.)))
83
- # brainstate.util.clear_buffer_memory()
80
+ print()
81
+ for i in range(100):
82
+ brainstate.environ.set(i=i)
83
+ rotation_delay.update(jnp.ones((1,)) * i)
84
+ print(i, rotation_delay.at('a'), rotation_delay.at('b'), rotation_delay.at('c'))
85
+ self.assertTrue(jnp.allclose(rotation_delay.at('a'), jnp.ones((1,)) * i))
86
+ self.assertTrue(jnp.allclose(rotation_delay.at('b'), jnp.maximum(jnp.ones((1,)) * i - n1, 0.)))
87
+ self.assertTrue(jnp.allclose(rotation_delay.at('c'), jnp.maximum(jnp.ones((1,)) * i - n2, 0.)))
88
+ # brainstate.util.clear_buffer_memory()
84
89
 
85
90
  def test_jit_erro(self):
86
91
  rotation_delay = brainstate.nn.Delay(jnp.ones([1]), time=2., delay_method='concat', interp_method='round')