brainstate 0.1.1__py2.py3-none-any.whl → 0.1.3__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (73)
  1. brainstate/__init__.py +1 -1
  2. brainstate/_compatible_import.py +12 -9
  3. brainstate/_state.py +1 -1
  4. brainstate/augment/_autograd_test.py +132 -133
  5. brainstate/augment/_eval_shape_test.py +7 -9
  6. brainstate/augment/_mapping_test.py +75 -76
  7. brainstate/compile/_ad_checkpoint_test.py +6 -8
  8. brainstate/compile/_conditions_test.py +35 -36
  9. brainstate/compile/_error_if_test.py +10 -13
  10. brainstate/compile/_loop_collect_return_test.py +7 -9
  11. brainstate/compile/_loop_no_collection_test.py +7 -8
  12. brainstate/compile/_make_jaxpr.py +29 -14
  13. brainstate/compile/_make_jaxpr_test.py +20 -20
  14. brainstate/functional/_activations_test.py +61 -61
  15. brainstate/graph/_graph_node_test.py +16 -18
  16. brainstate/graph/_graph_operation_test.py +154 -156
  17. brainstate/init/_random_inits_test.py +20 -21
  18. brainstate/init/_regular_inits_test.py +4 -5
  19. brainstate/mixin.py +1 -14
  20. brainstate/nn/__init__.py +81 -17
  21. brainstate/nn/_collective_ops_test.py +8 -8
  22. brainstate/nn/{_interaction/_conv.py → _conv.py} +1 -1
  23. brainstate/nn/{_interaction/_conv_test.py → _conv_test.py} +31 -33
  24. brainstate/nn/{_dynamics/_state_delay.py → _delay.py} +6 -2
  25. brainstate/nn/{_elementwise/_dropout.py → _dropout.py} +1 -1
  26. brainstate/nn/{_elementwise/_dropout_test.py → _dropout_test.py} +9 -9
  27. brainstate/nn/{_dynamics/_dynamics_base.py → _dynamics.py} +139 -18
  28. brainstate/nn/{_dynamics/_dynamics_base_test.py → _dynamics_test.py} +14 -15
  29. brainstate/nn/{_elementwise/_elementwise.py → _elementwise.py} +1 -1
  30. brainstate/nn/_elementwise_test.py +169 -0
  31. brainstate/nn/{_interaction/_embedding.py → _embedding.py} +1 -1
  32. brainstate/nn/_exp_euler_test.py +5 -6
  33. brainstate/nn/{_event/_fixedprob_mv.py → _fixedprob_mv.py} +1 -1
  34. brainstate/nn/{_event/_fixedprob_mv_test.py → _fixedprob_mv_test.py} +0 -1
  35. brainstate/nn/{_dyn_impl/_inputs.py → _inputs.py} +4 -5
  36. brainstate/nn/{_interaction/_linear.py → _linear.py} +2 -5
  37. brainstate/nn/{_event/_linear_mv.py → _linear_mv.py} +1 -1
  38. brainstate/nn/{_event/_linear_mv_test.py → _linear_mv_test.py} +0 -1
  39. brainstate/nn/{_interaction/_linear_test.py → _linear_test.py} +15 -17
  40. brainstate/nn/{_event/__init__.py → _ltp.py} +7 -5
  41. brainstate/nn/_module_test.py +34 -37
  42. brainstate/nn/{_dyn_impl/_dynamics_neuron.py → _neuron.py} +2 -2
  43. brainstate/nn/{_dyn_impl/_dynamics_neuron_test.py → _neuron_test.py} +18 -19
  44. brainstate/nn/{_interaction/_normalizations.py → _normalizations.py} +1 -1
  45. brainstate/nn/{_interaction/_normalizations_test.py → _normalizations_test.py} +10 -12
  46. brainstate/nn/{_interaction/_poolings.py → _poolings.py} +1 -1
  47. brainstate/nn/{_interaction/_poolings_test.py → _poolings_test.py} +20 -22
  48. brainstate/nn/{_dynamics/_projection_base.py → _projection.py} +35 -3
  49. brainstate/nn/{_dyn_impl/_rate_rnns.py → _rate_rnns.py} +2 -2
  50. brainstate/nn/{_dyn_impl/_rate_rnns_test.py → _rate_rnns_test.py} +6 -7
  51. brainstate/nn/{_dyn_impl/_readout.py → _readout.py} +3 -3
  52. brainstate/nn/{_dyn_impl/_readout_test.py → _readout_test.py} +9 -10
  53. brainstate/nn/_stp.py +236 -0
  54. brainstate/nn/{_dyn_impl/_dynamics_synapse.py → _synapse.py} +17 -206
  55. brainstate/nn/{_dyn_impl/_dynamics_synapse_test.py → _synapse_test.py} +9 -10
  56. brainstate/nn/_synaptic_projection.py +133 -0
  57. brainstate/nn/{_dynamics/_synouts.py → _synouts.py} +4 -1
  58. brainstate/nn/{_dynamics/_synouts_test.py → _synouts_test.py} +4 -5
  59. brainstate/optim/_lr_scheduler_test.py +3 -3
  60. brainstate/optim/_optax_optimizer_test.py +8 -9
  61. brainstate/random/_rand_funs_test.py +183 -184
  62. brainstate/random/_rand_seed_test.py +10 -12
  63. {brainstate-0.1.1.dist-info → brainstate-0.1.3.dist-info}/METADATA +1 -1
  64. brainstate-0.1.3.dist-info/RECORD +131 -0
  65. brainstate/nn/_dyn_impl/__init__.py +0 -42
  66. brainstate/nn/_dynamics/__init__.py +0 -37
  67. brainstate/nn/_elementwise/__init__.py +0 -22
  68. brainstate/nn/_elementwise/_elementwise_test.py +0 -171
  69. brainstate/nn/_interaction/__init__.py +0 -41
  70. brainstate-0.1.1.dist-info/RECORD +0 -133
  71. {brainstate-0.1.1.dist-info → brainstate-0.1.3.dist-info}/LICENSE +0 -0
  72. {brainstate-0.1.1.dist-info → brainstate-0.1.3.dist-info}/WHEEL +0 -0
  73. {brainstate-0.1.1.dist-info → brainstate-0.1.3.dist-info}/top_level.txt +0 -0
@@ -17,10 +17,10 @@ from typing import Union, Callable, Optional
17
17
 
18
18
  from brainstate._state import State
19
19
  from brainstate.mixin import AlignPost, ParamDescriber, BindCondData, JointTypes
20
- from brainstate.nn._collective_ops import call_order
21
- from brainstate.nn._module import Module
22
20
  from brainstate.util._others import get_unique_name
23
- from ._dynamics_base import Dynamics, maybe_init_prefetch, Prefetch, PrefetchDelayAt
21
+ from ._collective_ops import call_order
22
+ from ._dynamics import Dynamics, maybe_init_prefetch, Prefetch, PrefetchDelayAt
23
+ from ._module import Module
24
24
  from ._synouts import SynOut
25
25
 
26
26
  __all__ = [
@@ -360,3 +360,35 @@ class CurrentProj(Interaction):
360
360
  x = self.prefetch(*x)
361
361
  x = self.comm(x)
362
362
  self.out.bind_cond(x)
363
+
364
+
365
class RawProj(Interaction):
    """
    Minimal projection: pass the input through ``comm`` and bind it to ``out``.

    On every ``update(x)`` call, ``x`` is passed through the communication
    callable ``comm``. The result is handed to ``out.bind_cond``. No prefetch or
    delay step is applied, so ``x`` is used exactly as given.

    At construction time, ``out`` is registered on ``post`` via
    ``post.add_current_input(self.name, out)``. The projection's name is
    generated automatically with ``get_unique_name``.

    Parameters
    ----------
    comm : Callable
        Communication function applied to the input in ``update``. Must be callable.
    out : SynOut
        Synaptic output that receives the communicated value through ``bind_cond``.
    post : Dynamics
        Postsynaptic dynamics on which ``out`` is registered as a current input.

    Raises
    ------
    TypeError
        If ``comm`` is not callable, ``out`` is not a ``SynOut``, or ``post`` is
        not a ``Dynamics``.
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        comm: Callable,
        out: SynOut,
        post: Dynamics,
    ):
        super().__init__(name=get_unique_name(self.__class__.__name__))

        # ``comm`` is invoked in ``update``. Validate it up front so that a
        # non-callable fails here, before ``out`` is registered on ``post``.
        if not callable(comm):
            raise TypeError(f'The comm should be callable, but got {comm}.')

        # check out
        if not isinstance(out, SynOut):
            raise TypeError(f'The out should be a SynOut, but got {out}.')
        self.out = out

        # check post, then register ``out`` as one of its current inputs
        if not isinstance(post, Dynamics):
            raise TypeError(f'The post should be a Dynamics, but got {post}.')
        self.post = post
        post.add_current_input(self.name, out)

        # communication function
        self.comm = comm

    def update(self, x):
        """Apply ``comm`` to ``x`` and bind the result to ``out``. Returns ``None``."""
        x = self.comm(x)
        self.out.bind_cond(x)
@@ -22,9 +22,9 @@ import jax.numpy as jnp
22
22
 
23
23
  from brainstate import random, init, functional
24
24
  from brainstate._state import HiddenState, ParamState
25
- from brainstate.nn._interaction._linear import Linear
26
- from brainstate.nn._module import Module
27
25
  from brainstate.typing import ArrayLike
26
+ from ._linear import Linear
27
+ from ._module import Module
28
28
 
29
29
  __all__ = [
30
30
  'RNNCell', 'ValinaRNNCell', 'GRUCell', 'MGUCell', 'LSTMCell', 'URLSTMCell',
@@ -13,13 +13,12 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- from __future__ import annotations
17
16
 
18
17
  import unittest
19
18
 
20
19
  import jax.numpy as jnp
21
20
 
22
- import brainstate as bst
21
+ import brainstate
23
22
 
24
23
 
25
24
  class TestRateRNNModels(unittest.TestCase):
@@ -30,31 +29,31 @@ class TestRateRNNModels(unittest.TestCase):
30
29
  self.x = jnp.ones((self.batch_size, self.num_in))
31
30
 
32
31
  def test_ValinaRNNCell(self):
33
- model = bst.nn.ValinaRNNCell(num_in=self.num_in, num_out=self.num_out)
32
+ model = brainstate.nn.ValinaRNNCell(num_in=self.num_in, num_out=self.num_out)
34
33
  model.init_state(batch_size=self.batch_size)
35
34
  output = model.update(self.x)
36
35
  self.assertEqual(output.shape, (self.batch_size, self.num_out))
37
36
 
38
37
  def test_GRUCell(self):
39
- model = bst.nn.GRUCell(num_in=self.num_in, num_out=self.num_out)
38
+ model = brainstate.nn.GRUCell(num_in=self.num_in, num_out=self.num_out)
40
39
  model.init_state(batch_size=self.batch_size)
41
40
  output = model.update(self.x)
42
41
  self.assertEqual(output.shape, (self.batch_size, self.num_out))
43
42
 
44
43
  def test_MGUCell(self):
45
- model = bst.nn.MGUCell(num_in=self.num_in, num_out=self.num_out)
44
+ model = brainstate.nn.MGUCell(num_in=self.num_in, num_out=self.num_out)
46
45
  model.init_state(batch_size=self.batch_size)
47
46
  output = model.update(self.x)
48
47
  self.assertEqual(output.shape, (self.batch_size, self.num_out))
49
48
 
50
49
  def test_LSTMCell(self):
51
- model = bst.nn.LSTMCell(num_in=self.num_in, num_out=self.num_out)
50
+ model = brainstate.nn.LSTMCell(num_in=self.num_in, num_out=self.num_out)
52
51
  model.init_state(batch_size=self.batch_size)
53
52
  output = model.update(self.x)
54
53
  self.assertEqual(output.shape, (self.batch_size, self.num_out))
55
54
 
56
55
  def test_URLSTMCell(self):
57
- model = bst.nn.URLSTMCell(num_in=self.num_in, num_out=self.num_out)
56
+ model = brainstate.nn.URLSTMCell(num_in=self.num_in, num_out=self.num_out)
58
57
  model.init_state(batch_size=self.batch_size)
59
58
  output = model.update(self.x)
60
59
  self.assertEqual(output.shape, (self.batch_size, self.num_out))
@@ -24,10 +24,10 @@ import jax
24
24
 
25
25
  from brainstate import environ, init, surrogate
26
26
  from brainstate._state import HiddenState, ParamState
27
- from brainstate.nn._exp_euler import exp_euler_step
28
- from brainstate.nn._module import Module
29
27
  from brainstate.typing import Size, ArrayLike
30
- from ._dynamics_neuron import Neuron
28
+ from ._exp_euler import exp_euler_step
29
+ from ._module import Module
30
+ from ._neuron import Neuron
31
31
 
32
32
  __all__ = [
33
33
  'LeakyRateReadout',
@@ -13,13 +13,12 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- from __future__ import annotations
17
16
 
18
17
  import unittest
19
18
 
20
19
  import jax.numpy as jnp
21
20
 
22
- import brainstate as bst
21
+ import brainstate
23
22
 
24
23
 
25
24
  class TestReadoutModels(unittest.TestCase):
@@ -32,23 +31,23 @@ class TestReadoutModels(unittest.TestCase):
32
31
  self.x = jnp.ones((self.batch_size, self.in_size))
33
32
 
34
33
  def test_LeakyRateReadout(self):
  # Builds a LeakyRateReadout inside a dt=0.1 environment context, initialises its
  # state for `batch_size` samples, runs one update on `self.x`, and checks that
  # the output has shape (batch_size, out_size).
35
- with bst.environ.context(dt=0.1):
36
- model = bst.nn.LeakyRateReadout(in_size=self.in_size, out_size=self.out_size, tau=self.tau)
34
+ with brainstate.environ.context(dt=0.1):
35
+ model = brainstate.nn.LeakyRateReadout(in_size=self.in_size, out_size=self.out_size, tau=self.tau)
37
36
  model.init_state(batch_size=self.batch_size)
38
37
  output = model.update(self.x)
39
38
  self.assertEqual(output.shape, (self.batch_size, self.out_size))
40
39
 
41
40
  def test_LeakySpikeReadout(self):
  # Builds a LeakySpikeReadout (zero V initializer, Kaiming-normal weights) inside a
  # dt=0.1 environment context, initialises its state, calls update() inside a
  # t=0. context, and checks that the output has shape (batch_size, out_size).
  # NOTE(review): no out_size is passed to the constructor, yet the output is expected
  # to have self.out_size columns; presumably setUp makes in_size == out_size. Confirm.
42
- with bst.environ.context(dt=0.1):
43
- model = bst.nn.LeakySpikeReadout(in_size=self.in_size, tau=self.tau, V_th=self.V_th,
44
- V_initializer=bst.init.ZeroInit(),
45
- w_init=bst.init.KaimingNormal())
41
+ with brainstate.environ.context(dt=0.1):
42
+ model = brainstate.nn.LeakySpikeReadout(in_size=self.in_size, tau=self.tau, V_th=self.V_th,
43
+ V_initializer=brainstate.init.ZeroInit(),
44
+ w_init=brainstate.init.KaimingNormal())
46
45
  model.init_state(batch_size=self.batch_size)
47
- with bst.environ.context(t=0.):
46
+ with brainstate.environ.context(t=0.):
48
47
  output = model.update(self.x)
49
48
  self.assertEqual(output.shape, (self.batch_size, self.out_size))
50
49
 
51
50
 
52
51
  if __name__ == '__main__':
53
- with bst.environ.context(dt=0.1):
52
+ with brainstate.environ.context(dt=0.1):
54
53
  unittest.main()
brainstate/nn/_stp.py ADDED
@@ -0,0 +1,236 @@
1
+ # Copyright 2025 BDP Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ # -*- coding: utf-8 -*-
17
+
18
+ from typing import Optional
19
+
20
+ import brainunit as u
21
+
22
+ from brainstate import init
23
+ from brainstate._state import HiddenState
24
+ from brainstate.typing import ArrayLike, Size
25
+ from ._exp_euler import exp_euler_step
26
+ from ._synapse import Synapse
27
+
28
# Public short-term plasticity models exported by this module.
__all__ = [
    'ShortTermPlasticity',
    'STP',
    'STD',
]
31
+
32
+
33
class ShortTermPlasticity(Synapse):
    """
    Common base type for short-term plasticity synapse models.

    It adds no behaviour on top of ``Synapse``. It only groups the short-term
    plasticity models (such as ``STP`` and ``STD``) under one class.
    """
35
+
36
+
37
class STP(ShortTermPlasticity):
    r"""
    Synapse with short-term plasticity (facilitation and depression).

    Each synapse carries two variables:

    - a utilization variable :math:`u` (facilitation), and
    - an available-resource variable :math:`x` (depression).

    The model follows Tsodyks & Markram (1997) and Tsodyks, Pawelzik & Markram
    (1998). Its dynamics are

    $$
    \frac{du}{dt} = -\frac{u}{\tau_f} + U \cdot (1 - u) \cdot \delta(t - t_{spike})
    $$

    $$
    \frac{dx}{dt} = \frac{1 - x}{\tau_d} - u \cdot x \cdot \delta(t - t_{spike})
    $$

    and the effective synaptic efficacy is

    $$
    g_{syn} = u \cdot x
    $$

    Symbols:

    - $u$: utilization of synaptic efficacy (facilitation variable)
    - $x$: fraction of available synaptic resources (depression variable)
    - $\tau_f$: facilitation time constant
    - $\tau_d$: depression (resource recovery) time constant
    - $U$: baseline utilization increment per spike
    - $\delta(t - t_{spike})$: Dirac delta marking presynaptic spikes
    - $g_{syn}$: effective synaptic efficacy

    Parameters
    ----------
    in_size : Size
        Shape of the presynaptic input.
    name : str, optional
        Name of this synapse instance.
    U : ArrayLike, default=0.15
        Baseline utilization (fraction of resources released per spike).
    tau_f : ArrayLike, default=1500.*u.ms
        Facilitation time constant.
    tau_d : ArrayLike, default=200.*u.ms
        Time constant for recovery of synaptic resources.

    Attributes
    ----------
    u : HiddenState
        Facilitation variable. ``init_state`` and ``reset_state`` set it to ``U``.
    x : HiddenState
        Resource variable. ``init_state`` and ``reset_state`` set it to ``1``.

    Notes
    -----
    - Between spikes, ``u`` relaxes toward 0 and ``x`` recovers toward 1.
    - On a spike, the increment of ``u`` and the depletion of ``x`` use the state
      values held *before* this step's relaxation. The depletion of ``x`` uses the
      already-facilitated ``u``.
    - ``update`` returns ``u * x * pre_spike``, i.e. the efficacy gated by the input.
    - A larger ``tau_f`` gives longer-lasting facilitation. A larger ``tau_d`` gives
      slower recovery from depression.

    References
    ----------
    .. [1] Tsodyks, M. V., & Markram, H. (1997). The neural code between neocortical
           pyramidal neurons depends on neurotransmitter release probability.
           Proceedings of the National Academy of Sciences, 94(2), 719-723.
    .. [2] Tsodyks, M., Pawelzik, K., & Markram, H. (1998). Neural networks with dynamic
           synapses. Neural computation, 10(4), 821-835.
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        in_size: Size,
        name: Optional[str] = None,
        U: ArrayLike = 0.15,
        tau_f: ArrayLike = 1500. * u.ms,
        tau_d: ArrayLike = 200. * u.ms,
    ):
        super().__init__(name=name, in_size=in_size)

        # Materialise each parameter against the variable shape via ``init.param``.
        shape = self.varshape
        self.tau_f = init.param(tau_f, shape)
        self.tau_d = init.param(tau_d, shape)
        self.U = init.param(U, shape)

    def _initial_values(self, batch_size):
        """Return the starting ``(x, u)`` arrays: full resources (1) and utilization ``U``."""
        shape = self.varshape
        x0 = init.param(init.Constant(1.), shape, batch_size)
        u0 = init.param(init.Constant(self.U), shape, batch_size)
        return x0, u0

    def init_state(self, batch_size: int = None, **kwargs):
        x0, u0 = self._initial_values(batch_size)
        self.x = HiddenState(x0)
        self.u = HiddenState(u0)

    def reset_state(self, batch_size: int = None, **kwargs):
        x0, u0 = self._initial_values(batch_size)
        self.x.value = x0
        self.u.value = u0

    def update(self, pre_spike):
        """
        Advance ``u`` and ``x`` by one step and return the spike-gated efficacy.

        Parameters
        ----------
        pre_spike : ArrayLike
            Presynaptic spike signal. It multiplies the spike-triggered jumps, so it
            is presumably 0/1 (or boolean) per synapse.

        Returns
        -------
        ArrayLike
            ``u * x * pre_spike``, computed from the updated values.
        """
        u_prev = self.u.value
        x_prev = self.x.value

        # Continuous relaxation: u decays toward 0, x recovers toward 1.
        u_next = exp_euler_step(lambda v: -v / self.tau_f, u_prev)
        x_next = exp_euler_step(lambda v: (1 - v) / self.tau_d, x_prev)

        # Spike-triggered jumps (pre-relaxation state values, as in the model above).
        u_next = u_next + pre_spike * self.U * (1 - u_prev)
        x_next = x_next - pre_spike * u_next * x_prev

        self.u.value = u_next
        self.x.value = x_next
        return u_next * x_next * pre_spike
148
+
149
+
150
class STD(ShortTermPlasticity):
    r"""
    Synapse with short-term depression only.

    A single resource variable :math:`x` is depleted by each presynaptic spike and
    recovers toward 1 in between. This captures the activity-dependent drop in
    efficacy usually attributed to vesicle depletion. The dynamics are

    $$
    \frac{dx}{dt} = \frac{1 - x}{\tau} - U \cdot x \cdot \delta(t - t_{spike})
    $$

    $$
    g_{syn} = x
    $$

    Symbols:

    - $x$: fraction of available synaptic resources (depression variable)
    - $\tau$: resource recovery time constant
    - $U$: fraction of available resources consumed per spike
    - $\delta(t - t_{spike})$: Dirac delta marking presynaptic spikes
    - $g_{syn}$: effective synaptic efficacy

    Parameters
    ----------
    in_size : Size
        Shape of the presynaptic input.
    name : str, optional
        Name of this synapse instance.
    tau : ArrayLike, default=200.*u.ms
        Time constant for recovery of synaptic resources.
    U : ArrayLike, default=0.07
        Fraction of resources used per action potential.

    Attributes
    ----------
    x : HiddenState
        Resource variable. ``init_state`` and ``reset_state`` set it to ``1``.

    Notes
    -----
    - A larger ``tau`` means slower recovery. A larger ``U`` means deeper depression
      per spike.
    - The depletion on a spike uses the value of ``x`` from *before* this step's
      relaxation.
    - ``update`` returns the new ``x`` multiplied by ``pre_spike``.
    - This is the depression-only counterpart of ``STP``.

    References
    ----------
    .. [1] Abbott, L. F., Varela, J. A., Sen, K., & Nelson, S. B. (1997). Synaptic
           depression and cortical gain control. Science, 275(5297), 220-224.
    .. [2] Tsodyks, M. V., & Markram, H. (1997). The neural code between neocortical
           pyramidal neurons depends on neurotransmitter release probability.
           Proceedings of the National Academy of Sciences, 94(2), 719-723.
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        in_size: Size,
        name: Optional[str] = None,
        # synapse parameters
        tau: ArrayLike = 200. * u.ms,
        U: ArrayLike = 0.07,
    ):
        super().__init__(name=name, in_size=in_size)

        # Materialise each parameter against the variable shape via ``init.param``.
        shape = self.varshape
        self.tau = init.param(tau, shape)
        self.U = init.param(U, shape)

    def _full_resources(self, batch_size):
        """Return an array of ones shaped like the (optionally batched) resource state."""
        return init.param(init.Constant(1.), self.varshape, batch_size)

    def init_state(self, batch_size: int = None, **kwargs):
        self.x = HiddenState(self._full_resources(batch_size))

    def reset_state(self, batch_size: int = None, **kwargs):
        self.x.value = self._full_resources(batch_size)

    def update(self, pre_spike):
        """
        Relax and deplete ``x`` for one step, then return ``x * pre_spike``.

        ``pre_spike`` multiplies the depletion term, so it is presumably a 0/1
        (or boolean) spike indicator per synapse.
        """
        x_prev = self.x.value
        x_relaxed = exp_euler_step(lambda v: (1 - v) / self.tau, x_prev)
        self.x.value = x_relaxed - pre_spike * self.U * x_prev
        return self.x.value * pre_spike
@@ -23,12 +23,12 @@ import brainunit as u
23
23
  from brainstate import init, environ
24
24
  from brainstate._state import ShortTermState, HiddenState
25
25
  from brainstate.mixin import AlignPost
26
- from brainstate.nn._dynamics._dynamics_base import Dynamics
27
- from brainstate.nn._exp_euler import exp_euler_step
28
- from brainstate.typing import ArrayLike, Size
26
+ from brainstate.typing import ArrayLike, Size, PyTree
27
+ from ._dynamics import Dynamics
28
+ from ._exp_euler import exp_euler_step
29
29
 
30
30
  __all__ = [
31
- 'Synapse', 'Expon', 'DualExpon', 'Alpha', 'STP', 'STD', 'AMPA', 'GABAa',
31
+ 'Synapse', 'Expon', 'DualExpon', 'Alpha', 'AMPA', 'GABAa',
32
32
  ]
33
33
 
34
34
 
@@ -123,6 +123,9 @@ class Expon(Synapse, AlignPost):
123
123
  g = exp_euler_step(lambda g: self.sum_current_inputs(-g) / self.tau, self.g.value)
124
124
  self.g.value = self.sum_delta_inputs(g)
125
125
  if x is not None: self.g.value += x
126
+ return self.update_return()
127
+
128
def update_return(self) -> PyTree:
    """Return the current value of the ``g`` state as this synapse's output."""
    g_now = self.g.value
    return g_now
127
130
 
128
131
 
@@ -229,6 +232,9 @@ class DualExpon(Synapse, AlignPost):
229
232
  if x is not None:
230
233
  self.g_rise.value += x
231
234
  self.g_decay.value += x
235
+ return self.update_return()
236
+
237
def update_return(self) -> PyTree:
    """Return ``a * (g_decay - g_rise)`` evaluated on the current state values."""
    scale = self.a
    decay = self.g_decay.value
    rise = self.g_rise.value
    return scale * (decay - rise)
233
239
 
234
240
 
@@ -301,209 +307,10 @@ class Alpha(Synapse):
301
307
  self.h.value = self.sum_delta_inputs(h)
302
308
  if x is not None:
303
309
  self.h.value += x
304
- return self.g.value
305
-
306
-
307
- class STP(Synapse):
308
- r"""
309
- Synapse with short-term plasticity.
310
-
311
- This class implements a synapse model with short-term plasticity (STP), which captures
312
- activity-dependent changes in synaptic efficacy that occur over milliseconds to seconds.
313
- The model simultaneously accounts for both short-term facilitation and depression
314
- based on the formulation by Tsodyks & Markram (1998).
315
-
316
- The model is characterized by the following equations:
317
-
318
- $$
319
- \frac{du}{dt} = -\frac{u}{\tau_f} + U \cdot (1 - u) \cdot \delta(t - t_{spike})
320
- $$
321
-
322
- $$
323
- \frac{dx}{dt} = \frac{1 - x}{\tau_d} - u \cdot x \cdot \delta(t - t_{spike})
324
- $$
325
-
326
- $$
327
- g_{syn} = u \cdot x
328
- $$
329
-
330
- where:
331
- - $u$ represents the utilization of synaptic efficacy (facilitation variable)
332
- - $x$ represents the available synaptic resources (depression variable)
333
- - $\tau_f$ is the facilitation time constant
334
- - $\tau_d$ is the depression time constant
335
- - $U$ is the baseline utilization parameter
336
- - $\delta(t - t_{spike})$ is the Dirac delta function representing presynaptic spikes
337
- - $g_{syn}$ is the effective synaptic conductance
338
-
339
- Parameters
340
- ----------
341
- in_size : Size
342
- Size of the input.
343
- name : str, optional
344
- Name of the synapse instance.
345
- U : ArrayLike, default=0.15
346
- Baseline utilization parameter (fraction of resources used per action potential).
347
- tau_f : ArrayLike, default=1500.*u.ms
348
- Time constant of short-term facilitation in milliseconds.
349
- tau_d : ArrayLike, default=200.*u.ms
350
- Time constant of short-term depression (recovery of synaptic resources) in milliseconds.
351
-
352
- Attributes
353
- ----------
354
- u : HiddenState
355
- Utilization of synaptic efficacy (facilitation variable).
356
- x : HiddenState
357
- Available synaptic resources (depression variable).
358
-
359
- Notes
360
- -----
361
- - Larger values of tau_f produce stronger facilitation effects.
362
- - Larger values of tau_d lead to slower recovery from depression.
363
- - The parameter U controls the initial release probability.
364
- - The effective synaptic strength is the product of u and x.
365
-
366
- References
367
- ----------
368
- .. [1] Tsodyks, M. V., & Markram, H. (1997). The neural code between neocortical
369
- pyramidal neurons depends on neurotransmitter release probability.
370
- Proceedings of the National Academy of Sciences, 94(2), 719-723.
371
- .. [2] Tsodyks, M., Pawelzik, K., & Markram, H. (1998). Neural networks with dynamic
372
- synapses. Neural computation, 10(4), 821-835.
373
- """
374
- __module__ = 'brainstate.nn'
375
-
376
- def __init__(
377
- self,
378
- in_size: Size,
379
- name: Optional[str] = None,
380
- U: ArrayLike = 0.15,
381
- tau_f: ArrayLike = 1500. * u.ms,
382
- tau_d: ArrayLike = 200. * u.ms,
383
- ):
384
- super().__init__(name=name, in_size=in_size)
385
-
386
- # parameters
387
- self.tau_f = init.param(tau_f, self.varshape)
388
- self.tau_d = init.param(tau_d, self.varshape)
389
- self.U = init.param(U, self.varshape)
390
-
391
- def init_state(self, batch_size: int = None, **kwargs):
392
- self.x = HiddenState(init.param(init.Constant(1.), self.varshape, batch_size))
393
- self.u = HiddenState(init.param(init.Constant(self.U), self.varshape, batch_size))
394
-
395
- def reset_state(self, batch_size: int = None, **kwargs):
396
- self.x.value = init.param(init.Constant(1.), self.varshape, batch_size)
397
- self.u.value = init.param(init.Constant(self.U), self.varshape, batch_size)
398
-
399
- def update(self, pre_spike):
400
- u = exp_euler_step(lambda u: - u / self.tau_f, self.u.value)
401
- x = exp_euler_step(lambda x: (1 - x) / self.tau_d, self.x.value)
402
-
403
- # --- original code:
404
- # if pre_spike.dtype == jax.numpy.bool_:
405
- # u = bm.where(pre_spike, u + self.U * (1 - self.u), u)
406
- # x = bm.where(pre_spike, x - u * self.x, x)
407
- # else:
408
- # u = pre_spike * (u + self.U * (1 - self.u)) + (1 - pre_spike) * u
409
- # x = pre_spike * (x - u * self.x) + (1 - pre_spike) * x
310
+ return self.update_return()
410
311
 
411
- # --- simplified code:
412
- u = u + pre_spike * self.U * (1 - self.u.value)
413
- x = x - pre_spike * u * self.x.value
414
-
415
- self.u.value = u
416
- self.x.value = x
417
- return u * x * pre_spike
418
-
419
-
420
- class STD(Synapse):
421
- r"""
422
- Synapse with short-term depression.
423
-
424
- This class implements a synapse model with short-term depression (STD), which captures
425
- activity-dependent reduction in synaptic efficacy, typically caused by depletion of
426
- neurotransmitter vesicles following repeated stimulation.
427
-
428
- The model is characterized by the following equation:
429
-
430
- $$
431
- \frac{dx}{dt} = \frac{1 - x}{\tau} - U \cdot x \cdot \delta(t - t_{spike})
432
- $$
433
-
434
- $$
435
- g_{syn} = x
436
- $$
437
-
438
- where:
439
- - $x$ represents the available synaptic resources (depression variable)
440
- - $\tau$ is the depression recovery time constant
441
- - $U$ is the utilization parameter (fraction of resources depleted per spike)
442
- - $\delta(t - t_{spike})$ is the Dirac delta function representing presynaptic spikes
443
- - $g_{syn}$ is the effective synaptic conductance
444
-
445
- Parameters
446
- ----------
447
- in_size : Size
448
- Size of the input.
449
- name : str, optional
450
- Name of the synapse instance.
451
- tau : ArrayLike, default=200.*u.ms
452
- Time constant governing recovery of synaptic resources in milliseconds.
453
- U : ArrayLike, default=0.07
454
- Utilization parameter (fraction of resources used per action potential).
455
-
456
- Attributes
457
- ----------
458
- x : HiddenState
459
- Available synaptic resources (depression variable).
460
-
461
- Notes
462
- -----
463
- - Larger values of tau lead to slower recovery from depression.
464
- - Larger values of U cause stronger depression with each spike.
465
- - This model is a simplified version of the STP model that only includes depression.
466
-
467
- References
468
- ----------
469
- .. [1] Abbott, L. F., Varela, J. A., Sen, K., & Nelson, S. B. (1997). Synaptic
470
- depression and cortical gain control. Science, 275(5297), 220-224.
471
- .. [2] Tsodyks, M. V., & Markram, H. (1997). The neural code between neocortical
472
- pyramidal neurons depends on neurotransmitter release probability.
473
- Proceedings of the National Academy of Sciences, 94(2), 719-723.
474
- """
475
- __module__ = 'brainstate.nn'
476
-
477
- def __init__(
478
- self,
479
- in_size: Size,
480
- name: Optional[str] = None,
481
- # synapse parameters
482
- tau: ArrayLike = 200. * u.ms,
483
- U: ArrayLike = 0.07,
484
- ):
485
- super().__init__(name=name, in_size=in_size)
486
-
487
- # parameters
488
- self.tau = init.param(tau, self.varshape)
489
- self.U = init.param(U, self.varshape)
490
-
491
- def init_state(self, batch_size: int = None, **kwargs):
492
- self.x = HiddenState(init.param(init.Constant(1.), self.varshape, batch_size))
493
-
494
- def reset_state(self, batch_size: int = None, **kwargs):
495
- self.x.value = init.param(init.Constant(1.), self.varshape, batch_size)
496
-
497
- def update(self, pre_spike):
498
- x = exp_euler_step(lambda x: (1 - x) / self.tau, self.x.value)
499
-
500
- # --- original code:
501
- # self.x.value = bm.where(pre_spike, x - self.U * self.x, x)
502
-
503
- # --- simplified code:
504
- self.x.value = x - pre_spike * self.U * self.x.value
505
-
506
- return self.x.value * pre_spike
312
def update_return(self) -> PyTree:
    """Return the current value of the ``g`` state as this synapse's output."""
    g_now = self.g.value
    return g_now
507
314
 
508
315
 
509
316
  class AMPA(Synapse):
@@ -614,6 +421,10 @@ class AMPA(Synapse):
614
421
  self.spike_arrival_time.value = u.math.where(pre_spike, t, self.spike_arrival_time.value)
615
422
  TT = ((t - self.spike_arrival_time.value) < self.T_duration) * self.T
616
423
  self.g.value = exp_euler_step(self.dg, self.g.value, t, TT)
424
+ return self.update_return()
425
+
426
def update_return(self) -> PyTree:
    """Return the synaptic conductance value."""
    conductance = self.g.value
    return conductance
618
429
 
619
430