brainstate 0.1.1__py2.py3-none-any.whl → 0.1.3__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (73)
  1. brainstate/__init__.py +1 -1
  2. brainstate/_compatible_import.py +12 -9
  3. brainstate/_state.py +1 -1
  4. brainstate/augment/_autograd_test.py +132 -133
  5. brainstate/augment/_eval_shape_test.py +7 -9
  6. brainstate/augment/_mapping_test.py +75 -76
  7. brainstate/compile/_ad_checkpoint_test.py +6 -8
  8. brainstate/compile/_conditions_test.py +35 -36
  9. brainstate/compile/_error_if_test.py +10 -13
  10. brainstate/compile/_loop_collect_return_test.py +7 -9
  11. brainstate/compile/_loop_no_collection_test.py +7 -8
  12. brainstate/compile/_make_jaxpr.py +29 -14
  13. brainstate/compile/_make_jaxpr_test.py +20 -20
  14. brainstate/functional/_activations_test.py +61 -61
  15. brainstate/graph/_graph_node_test.py +16 -18
  16. brainstate/graph/_graph_operation_test.py +154 -156
  17. brainstate/init/_random_inits_test.py +20 -21
  18. brainstate/init/_regular_inits_test.py +4 -5
  19. brainstate/mixin.py +1 -14
  20. brainstate/nn/__init__.py +81 -17
  21. brainstate/nn/_collective_ops_test.py +8 -8
  22. brainstate/nn/{_interaction/_conv.py → _conv.py} +1 -1
  23. brainstate/nn/{_interaction/_conv_test.py → _conv_test.py} +31 -33
  24. brainstate/nn/{_dynamics/_state_delay.py → _delay.py} +6 -2
  25. brainstate/nn/{_elementwise/_dropout.py → _dropout.py} +1 -1
  26. brainstate/nn/{_elementwise/_dropout_test.py → _dropout_test.py} +9 -9
  27. brainstate/nn/{_dynamics/_dynamics_base.py → _dynamics.py} +139 -18
  28. brainstate/nn/{_dynamics/_dynamics_base_test.py → _dynamics_test.py} +14 -15
  29. brainstate/nn/{_elementwise/_elementwise.py → _elementwise.py} +1 -1
  30. brainstate/nn/_elementwise_test.py +169 -0
  31. brainstate/nn/{_interaction/_embedding.py → _embedding.py} +1 -1
  32. brainstate/nn/_exp_euler_test.py +5 -6
  33. brainstate/nn/{_event/_fixedprob_mv.py → _fixedprob_mv.py} +1 -1
  34. brainstate/nn/{_event/_fixedprob_mv_test.py → _fixedprob_mv_test.py} +0 -1
  35. brainstate/nn/{_dyn_impl/_inputs.py → _inputs.py} +4 -5
  36. brainstate/nn/{_interaction/_linear.py → _linear.py} +2 -5
  37. brainstate/nn/{_event/_linear_mv.py → _linear_mv.py} +1 -1
  38. brainstate/nn/{_event/_linear_mv_test.py → _linear_mv_test.py} +0 -1
  39. brainstate/nn/{_interaction/_linear_test.py → _linear_test.py} +15 -17
  40. brainstate/nn/{_event/__init__.py → _ltp.py} +7 -5
  41. brainstate/nn/_module_test.py +34 -37
  42. brainstate/nn/{_dyn_impl/_dynamics_neuron.py → _neuron.py} +2 -2
  43. brainstate/nn/{_dyn_impl/_dynamics_neuron_test.py → _neuron_test.py} +18 -19
  44. brainstate/nn/{_interaction/_normalizations.py → _normalizations.py} +1 -1
  45. brainstate/nn/{_interaction/_normalizations_test.py → _normalizations_test.py} +10 -12
  46. brainstate/nn/{_interaction/_poolings.py → _poolings.py} +1 -1
  47. brainstate/nn/{_interaction/_poolings_test.py → _poolings_test.py} +20 -22
  48. brainstate/nn/{_dynamics/_projection_base.py → _projection.py} +35 -3
  49. brainstate/nn/{_dyn_impl/_rate_rnns.py → _rate_rnns.py} +2 -2
  50. brainstate/nn/{_dyn_impl/_rate_rnns_test.py → _rate_rnns_test.py} +6 -7
  51. brainstate/nn/{_dyn_impl/_readout.py → _readout.py} +3 -3
  52. brainstate/nn/{_dyn_impl/_readout_test.py → _readout_test.py} +9 -10
  53. brainstate/nn/_stp.py +236 -0
  54. brainstate/nn/{_dyn_impl/_dynamics_synapse.py → _synapse.py} +17 -206
  55. brainstate/nn/{_dyn_impl/_dynamics_synapse_test.py → _synapse_test.py} +9 -10
  56. brainstate/nn/_synaptic_projection.py +133 -0
  57. brainstate/nn/{_dynamics/_synouts.py → _synouts.py} +4 -1
  58. brainstate/nn/{_dynamics/_synouts_test.py → _synouts_test.py} +4 -5
  59. brainstate/optim/_lr_scheduler_test.py +3 -3
  60. brainstate/optim/_optax_optimizer_test.py +8 -9
  61. brainstate/random/_rand_funs_test.py +183 -184
  62. brainstate/random/_rand_seed_test.py +10 -12
  63. {brainstate-0.1.1.dist-info → brainstate-0.1.3.dist-info}/METADATA +1 -1
  64. brainstate-0.1.3.dist-info/RECORD +131 -0
  65. brainstate/nn/_dyn_impl/__init__.py +0 -42
  66. brainstate/nn/_dynamics/__init__.py +0 -37
  67. brainstate/nn/_elementwise/__init__.py +0 -22
  68. brainstate/nn/_elementwise/_elementwise_test.py +0 -171
  69. brainstate/nn/_interaction/__init__.py +0 -41
  70. brainstate-0.1.1.dist-info/RECORD +0 -133
  71. {brainstate-0.1.1.dist-info → brainstate-0.1.3.dist-info}/LICENSE +0 -0
  72. {brainstate-0.1.1.dist-info → brainstate-0.1.3.dist-info}/WHEEL +0 -0
  73. {brainstate-0.1.1.dist-info → brainstate-0.1.3.dist-info}/top_level.txt +0 -0
@@ -13,7 +13,6 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- from __future__ import annotations
17
16
 
18
17
  import unittest
19
18
 
@@ -21,7 +20,7 @@ import brainunit as u
21
20
  import jax.numpy as jnp
22
21
  import pytest
23
22
 
24
- import brainstate as bst
23
+ import brainstate
25
24
  from brainstate.nn import Expon, STP, STD
26
25
 
27
26
 
@@ -32,7 +31,7 @@ class TestSynapse(unittest.TestCase):
32
31
  self.time_steps = 100
33
32
 
34
33
  def generate_input(self):
35
- return bst.random.randn(self.time_steps, self.batch_size, self.in_size) * u.mS
34
+ return brainstate.random.randn(self.time_steps, self.batch_size, self.in_size) * u.mS
36
35
 
37
36
  def test_expon_synapse(self):
38
37
  tau = 20.0 * u.ms
@@ -46,8 +45,8 @@ class TestSynapse(unittest.TestCase):
46
45
 
47
46
  # Test forward pass
48
47
  state = synapse.init_state(self.batch_size)
49
- call = bst.compile.jit(synapse)
50
- with bst.environ.context(dt=0.1 * u.ms):
48
+ call = brainstate.compile.jit(synapse)
49
+ with brainstate.environ.context(dt=0.1 * u.ms):
51
50
  for t in range(self.time_steps):
52
51
  out = call(inputs[t])
53
52
  self.assertEqual(out.shape, (self.batch_size, self.in_size))
@@ -75,7 +74,7 @@ class TestSynapse(unittest.TestCase):
75
74
 
76
75
  # Test forward pass
77
76
  state = synapse.init_state(self.batch_size)
78
- call = bst.compile.jit(synapse)
77
+ call = brainstate.compile.jit(synapse)
79
78
  for t in range(self.time_steps):
80
79
  out = call(inputs[t])
81
80
  self.assertEqual(out.shape, (self.batch_size, self.in_size))
@@ -118,15 +117,15 @@ class TestSynapse(unittest.TestCase):
118
117
  self.assertEqual(synapse.in_size, in_size)
119
118
  self.assertEqual(synapse.out_size, in_size)
120
119
 
121
- inputs = bst.random.randn(self.time_steps, self.batch_size, *in_size) * u.mS
120
+ inputs = brainstate.random.randn(self.time_steps, self.batch_size, *in_size) * u.mS
122
121
  state = synapse.init_state(self.batch_size)
123
- call = bst.compile.jit(synapse)
124
- with bst.environ.context(dt=0.1 * u.ms):
122
+ call = brainstate.compile.jit(synapse)
123
+ with brainstate.environ.context(dt=0.1 * u.ms):
125
124
  for t in range(self.time_steps):
126
125
  out = call(inputs[t])
127
126
  self.assertEqual(out.shape, (self.batch_size, *in_size))
128
127
 
129
128
 
130
129
  if __name__ == '__main__':
131
- with bst.environ.context(dt=0.1):
130
+ with brainstate.environ.context(dt=0.1):
132
131
  unittest.main()
@@ -0,0 +1,133 @@
1
+ # Copyright 2025 BDP Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ # -*- coding: utf-8 -*-
16
+
17
+
18
+ from typing import Callable, Union
19
+
20
+ import brainunit as u
21
+
22
+ from brainstate._compatible_import import brainevent
23
+ from brainstate.mixin import ParamDescriber, AlignPost, UpdateReturn
24
+ from ._dynamics import Dynamics, Projection
25
+ from ._projection import AlignPostProj, RawProj
26
+ from ._stp import ShortTermPlasticity
27
+ from ._synapse import Synapse
28
+ from ._synouts import SynOut
29
+
30
# Public API of this module: the two projection builders.
__all__ = ['align_pre_projection', 'align_post_projection']
34
+
35
+
36
class align_pre_projection(Projection):
    """
    Represents a pre-synaptic alignment projection mechanism.

    The synapse is aligned with the pre-synaptic dynamics through
    ``pre.align_pre(syn)``. On every update, the delayed synaptic output is
    read, optionally passed through short-term plasticity, and then sent
    through a ``RawProj`` (communication -> synaptic output -> post-synaptic
    dynamics).

    Args:
        pre (Dynamics): The pre-synaptic dynamics object.
        syn (Union[Synapse, ParamDescriber[Synapse]]): The synapse, or a
            describer of it, handed to ``pre.align_pre``.
        delay (u.Quantity[u.second] | None): The delay passed to
            ``syn.output_delay``.
        comm (Callable): Communication function of the projection.
        out (SynOut): The synaptic output object.
        post (Dynamics): The post-synaptic dynamics object.
        stp (ShortTermPlasticity | None): Optional short-term plasticity
            applied to the delayed output. Defaults to None (no STP).

    Attributes:
        pre (Dynamics): The pre-synaptic dynamics object.
        syn (Synapse): The synaptic object returned by ``pre.align_pre(syn)``.
        delay: The object returned by ``syn.output_delay(delay)``. It is
            called with no arguments in :meth:`update` to obtain the value
            fed to STP / the projection.
        projection (RawProj): The raw projection object handling
            communication, output, and post-synaptic dynamics.
        stp (ShortTermPlasticity | None): The short-term plasticity object.

    Raises:
        TypeError: If the synapse returned by ``pre.align_pre`` does not
            implement the ``UpdateReturn`` interface.
    """

    def __init__(
        self,
        pre: Dynamics,
        syn: Union[Synapse, ParamDescriber[Synapse]],
        delay: u.Quantity[u.second] | None,
        comm: Callable,
        out: SynOut,
        post: Dynamics,
        stp: ShortTermPlasticity | None = None,
    ):
        super().__init__()
        # Let the pre-synaptic dynamics align the synapse; the returned object
        # is the synapse instance used from here on.
        syn: Synapse = pre.align_pre(syn)
        # The aligned synapse must implement ``UpdateReturn`` (i.e. provide
        # ``update_return()``). Use an explicit raise instead of ``assert``
        # so the check is not stripped when Python runs with ``-O``.
        if not isinstance(syn, UpdateReturn):
            raise TypeError("Synapse must implement UpdateReturn interface")
        self.delay = syn.output_delay(delay)
        self.projection = RawProj(comm=comm, out=out, post=post)
        self.stp = stp
        # Keep references documented as attributes. The original code had a
        # no-op ``pre = pre`` here instead of storing it.
        self.pre = pre
        self.syn = syn

    def update(self):
        """
        Read the delayed synaptic output and forward it through the projection.

        If short-term plasticity is configured, it is applied to the delayed
        output before the projection.

        Returns:
            Whatever ``self.projection`` returns for the (possibly
            STP-modulated) delayed output.
        """
        x = self.delay()
        if self.stp is not None:
            x = self.stp(x)
        return self.projection(x)
+
81
+
82
class align_post_projection(Projection):
    """
    Represents a post-synaptic alignment projection mechanism.

    On every update the positional inputs are passed through the chain of
    spike generators. The single resulting value is wrapped as a
    ``brainevent.BinaryArray``. If short-term plasticity is configured, the
    STP output is wrapped as a ``brainevent.MaskedFloat``. The result is then
    sent through an ``AlignPostProj``.

    Args:
        *spike_generator: Callable(s) applied in order. Each receives the
            previous output unpacked as positional arguments. A tuple/list
            return value is treated as several values; anything else is
            treated as one value.
        comm (Callable): Communication function for the projection.
        syn (Union[AlignPost, ParamDescriber[AlignPost]]): The post-synaptic
            alignment object or its parameter describer.
        out (Union[SynOut, ParamDescriber[SynOut]]): The synaptic output
            object or its parameter describer.
        post (Dynamics): The post-synaptic dynamics object.
        stp (ShortTermPlasticity | None): Optional short-term plasticity
            object. Defaults to None (no STP).
    """

    def __init__(
        self,
        *spike_generator,
        comm: Callable,
        syn: Union[AlignPost, ParamDescriber[AlignPost]],
        out: Union[SynOut, ParamDescriber[SynOut]],
        post: Dynamics,
        stp: ShortTermPlasticity | None = None,
    ):
        super().__init__()
        self.spike_generator = spike_generator
        self.projection = AlignPostProj(comm=comm, syn=syn, out=out, post=post)
        self.stp = stp

    def update(self, *x):
        """
        Generate spikes from ``x``, apply optional STP, and project them.

        Args:
            *x: Positional inputs fed to the first spike generator. With no
                spike generators, exactly one input must be given.

        Returns:
            Whatever ``self.projection`` returns.

        Raises:
            ValueError: If the spike-generator chain (or the raw inputs, when
                there is no generator) does not end with exactly one value.
        """
        for fun in self.spike_generator:
            result = fun(*x)
            # Normalize to a tuple so the next generator can be called with
            # ``fun(*x)`` regardless of whether this one returned one or many values.
            x = tuple(result) if isinstance(result, (tuple, list)) else (result,)
        # Explicit raise instead of ``assert`` so the check survives ``-O``.
        # The previous message claimed a tuple/list of several values was
        # acceptable, contradicting the single-value requirement.
        if len(x) != 1:
            raise ValueError(
                "Spike generator must return a single value (or a tuple/list "
                f"containing exactly one value), but got {len(x)} values."
            )
        # Wrap the spikes as a binary event array.
        spikes = brainevent.BinaryArray(x[0])
        if self.stp is not None:
            # Wrap the STP output as a masked float event array.
            spikes = brainevent.MaskedFloat(self.stp(spikes))
        return self.projection(spikes)
126
+
127
+
128
class align_pre_ltp(Projection):
    """
    Placeholder projection class.

    It currently defines no behavior of its own beyond what it inherits from
    ``Projection``. It is not exported in ``__all__``.
    """
130
+
131
+
132
class align_post_ltp(Projection):
    """
    Placeholder projection class.

    It currently defines no behavior of its own beyond what it inherits from
    ``Projection``. It is not exported in ``__all__``.
    """
@@ -19,8 +19,8 @@ import brainunit as u
19
19
  import jax.numpy as jnp
20
20
 
21
21
  from brainstate.mixin import BindCondData
22
- from brainstate.nn._module import Module
23
22
  from brainstate.typing import ArrayLike
23
+ from ._module import Module
24
24
 
25
25
  __all__ = [
26
26
  'SynOut', 'COBA', 'CUBA', 'MgBlock',
@@ -47,6 +47,9 @@ class SynOut(Module, BindCondData):
47
47
  ret = self.update(self._conductance, *args, **kwargs)
48
48
  return ret
49
49
 
50
def update(self, conductance, potential):
    """
    Compute the synaptic output. Abstract in the ``SynOut`` base class.

    Subclasses are expected to override this method. In this base class it
    always raises.

    Args:
        conductance: The synaptic conductance.
        potential: The potential supplied by the caller (presumably the
            post-synaptic membrane potential for conductance-based outputs;
            confirm against subclasses).

    Raises:
        NotImplementedError: Always, in this base class.

    NOTE(review): some subclasses (e.g. ``CUBA``) are called as
    ``update(conductance)`` with no potential, so this base signature does
    not describe every override.
    """
    raise NotImplementedError()
52
+
50
53
 
51
54
  class COBA(SynOut):
52
55
  r"""
@@ -13,7 +13,6 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- from __future__ import annotations
17
16
 
18
17
  import unittest
19
18
 
@@ -21,7 +20,7 @@ import brainunit as u
21
20
  import jax.numpy as jnp
22
21
  import numpy as np
23
22
 
24
- import brainstate as bst
23
+ import brainstate
25
24
 
26
25
 
27
26
  class TestSynOutModels(unittest.TestCase):
@@ -35,19 +34,19 @@ class TestSynOutModels(unittest.TestCase):
35
34
  self.V_offset = jnp.array([0.0])
36
35
 
37
36
  def test_COBA(self):
38
- model = bst.nn.COBA(E=self.E)
37
+ model = brainstate.nn.COBA(E=self.E)
39
38
  output = model.update(self.conductance, self.potential)
40
39
  expected_output = self.conductance * (self.E - self.potential)
41
40
  np.testing.assert_array_almost_equal(output, expected_output)
42
41
 
43
42
  def test_CUBA(self):
44
- model = bst.nn.CUBA()
43
+ model = brainstate.nn.CUBA()
45
44
  output = model.update(self.conductance)
46
45
  expected_output = self.conductance * model.scale
47
46
  self.assertTrue(u.math.allclose(output, expected_output))
48
47
 
49
48
  def test_MgBlock(self):
50
- model = bst.nn.MgBlock(E=self.E, cc_Mg=self.cc_Mg, alpha=self.alpha, beta=self.beta, V_offset=self.V_offset)
49
+ model = brainstate.nn.MgBlock(E=self.E, cc_Mg=self.cc_Mg, alpha=self.alpha, beta=self.beta, V_offset=self.V_offset)
51
50
  output = model.update(self.conductance, self.potential)
52
51
  norm = (1 + self.cc_Mg / self.beta * jnp.exp(self.alpha * (self.V_offset - self.potential)))
53
52
  expected_output = self.conductance * (self.E - self.potential) / norm
@@ -19,12 +19,12 @@ import unittest
19
19
 
20
20
  import jax.numpy as jnp
21
21
 
22
- import brainstate as bst
22
+ import brainstate
23
23
 
24
24
 
25
25
  class TestMultiStepLR(unittest.TestCase):
26
26
  def test1(self):
27
- lr = bst.optim.MultiStepLR(0.1, [10, 20, 30], gamma=0.1)
27
+ lr = brainstate.optim.MultiStepLR(0.1, [10, 20, 30], gamma=0.1)
28
28
  for i in range(40):
29
29
  r = lr(i)
30
30
  if i < 10:
@@ -37,7 +37,7 @@ class TestMultiStepLR(unittest.TestCase):
37
37
  self.assertTrue(jnp.allclose(r, 0.0001))
38
38
 
39
39
  def test2(self):
40
- lr = bst.compile.jit(bst.optim.MultiStepLR(0.1, [10, 20, 30], gamma=0.1))
40
+ lr = brainstate.compile.jit(brainstate.optim.MultiStepLR(0.1, [10, 20, 30], gamma=0.1))
41
41
  for i in range(40):
42
42
  r = lr(i)
43
43
  if i < 10:
@@ -13,39 +13,38 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- from __future__ import annotations
17
16
 
18
17
  import unittest
19
18
 
20
19
  import jax
21
20
  import optax
22
21
 
23
- import brainstate as bst
22
+ import brainstate
24
23
 
25
24
 
26
25
  class TestOptaxOptimizer(unittest.TestCase):
27
26
  def test1(self):
28
- class Model(bst.nn.Module):
27
+ class Model(brainstate.nn.Module):
29
28
  def __init__(self):
30
29
  super().__init__()
31
- self.linear1 = bst.nn.Linear(2, 3)
32
- self.linear2 = bst.nn.Linear(3, 4)
30
+ self.linear1 = brainstate.nn.Linear(2, 3)
31
+ self.linear2 = brainstate.nn.Linear(3, 4)
33
32
 
34
33
  def __call__(self, x):
35
34
  return self.linear2(self.linear1(x))
36
35
 
37
- x = bst.random.randn(1, 2)
36
+ x = brainstate.random.randn(1, 2)
38
37
  y = jax.numpy.ones((1, 4))
39
38
 
40
39
  model = Model()
41
40
  tx = optax.adam(1e-3)
42
- optimizer = bst.optim.OptaxOptimizer(tx)
43
- optimizer.register_trainable_weights(model.states(bst.ParamState))
41
+ optimizer = brainstate.optim.OptaxOptimizer(tx)
42
+ optimizer.register_trainable_weights(model.states(brainstate.ParamState))
44
43
 
45
44
  loss_fn = lambda: ((model(x) - y) ** 2).mean()
46
45
  prev_loss = loss_fn()
47
46
 
48
- grads = bst.augment.grad(loss_fn, model.states(bst.ParamState))()
47
+ grads = brainstate.augment.grad(loss_fn, model.states(brainstate.ParamState))()
49
48
  optimizer.update(grads)
50
49
 
51
50
  new_loss = loss_fn()