brainstate-0.1.10-py2.py3-none-any.whl → brainstate-0.2.0-py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (163)
  1. brainstate/__init__.py +130 -19
  2. brainstate/_compatible_import.py +201 -9
  3. brainstate/_compatible_import_test.py +681 -0
  4. brainstate/_deprecation.py +210 -0
  5. brainstate/_deprecation_test.py +2319 -0
  6. brainstate/{util/error.py → _error.py} +10 -20
  7. brainstate/_state.py +94 -47
  8. brainstate/_state_test.py +1 -1
  9. brainstate/_utils.py +1 -1
  10. brainstate/environ.py +1279 -347
  11. brainstate/environ_test.py +1187 -26
  12. brainstate/graph/__init__.py +6 -13
  13. brainstate/graph/_node.py +240 -0
  14. brainstate/graph/_node_test.py +589 -0
  15. brainstate/graph/{_graph_operation.py → _operation.py} +632 -746
  16. brainstate/graph/_operation_test.py +1147 -0
  17. brainstate/mixin.py +1209 -141
  18. brainstate/mixin_test.py +991 -51
  19. brainstate/nn/__init__.py +74 -72
  20. brainstate/nn/_activations.py +587 -295
  21. brainstate/nn/_activations_test.py +109 -86
  22. brainstate/nn/_collective_ops.py +393 -274
  23. brainstate/nn/_collective_ops_test.py +746 -15
  24. brainstate/nn/_common.py +114 -66
  25. brainstate/nn/_common_test.py +154 -0
  26. brainstate/nn/_conv.py +1652 -143
  27. brainstate/nn/_conv_test.py +838 -227
  28. brainstate/nn/_delay.py +15 -28
  29. brainstate/nn/_delay_test.py +25 -20
  30. brainstate/nn/_dropout.py +359 -167
  31. brainstate/nn/_dropout_test.py +429 -52
  32. brainstate/nn/_dynamics.py +14 -90
  33. brainstate/nn/_dynamics_test.py +1 -12
  34. brainstate/nn/_elementwise.py +492 -313
  35. brainstate/nn/_elementwise_test.py +806 -145
  36. brainstate/nn/_embedding.py +369 -19
  37. brainstate/nn/_embedding_test.py +156 -0
  38. brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +10 -16
  39. brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +6 -5
  40. brainstate/nn/{_linear_mv.py → _event_linear.py} +2 -2
  41. brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +6 -5
  42. brainstate/nn/_exp_euler.py +200 -38
  43. brainstate/nn/_exp_euler_test.py +350 -8
  44. brainstate/nn/_linear.py +391 -71
  45. brainstate/nn/_linear_test.py +427 -59
  46. brainstate/nn/_metrics.py +1070 -0
  47. brainstate/nn/_metrics_test.py +611 -0
  48. brainstate/nn/_module.py +10 -3
  49. brainstate/nn/_module_test.py +1 -1
  50. brainstate/nn/_normalizations.py +688 -329
  51. brainstate/nn/_normalizations_test.py +663 -37
  52. brainstate/nn/_paddings.py +1020 -0
  53. brainstate/nn/_paddings_test.py +723 -0
  54. brainstate/nn/_poolings.py +1404 -342
  55. brainstate/nn/_poolings_test.py +828 -92
  56. brainstate/nn/{_rate_rnns.py → _rnns.py} +446 -54
  57. brainstate/nn/_rnns_test.py +593 -0
  58. brainstate/nn/_utils.py +132 -5
  59. brainstate/nn/_utils_test.py +402 -0
  60. brainstate/{init/_random_inits.py → nn/init.py} +301 -45
  61. brainstate/{init/_random_inits_test.py → nn/init_test.py} +51 -20
  62. brainstate/random/__init__.py +247 -1
  63. brainstate/random/_rand_funs.py +668 -346
  64. brainstate/random/_rand_funs_test.py +74 -1
  65. brainstate/random/_rand_seed.py +541 -76
  66. brainstate/random/_rand_seed_test.py +1 -1
  67. brainstate/random/_rand_state.py +601 -393
  68. brainstate/random/_rand_state_test.py +551 -0
  69. brainstate/transform/__init__.py +59 -0
  70. brainstate/transform/_ad_checkpoint.py +176 -0
  71. brainstate/{compile → transform}/_ad_checkpoint_test.py +1 -1
  72. brainstate/{augment → transform}/_autograd.py +360 -113
  73. brainstate/{augment → transform}/_autograd_test.py +2 -2
  74. brainstate/transform/_conditions.py +316 -0
  75. brainstate/{compile → transform}/_conditions_test.py +11 -11
  76. brainstate/{compile → transform}/_error_if.py +22 -20
  77. brainstate/{compile → transform}/_error_if_test.py +1 -1
  78. brainstate/transform/_eval_shape.py +145 -0
  79. brainstate/{augment → transform}/_eval_shape_test.py +1 -1
  80. brainstate/{compile → transform}/_jit.py +99 -46
  81. brainstate/{compile → transform}/_jit_test.py +3 -3
  82. brainstate/{compile → transform}/_loop_collect_return.py +219 -80
  83. brainstate/{compile → transform}/_loop_collect_return_test.py +1 -1
  84. brainstate/{compile → transform}/_loop_no_collection.py +133 -34
  85. brainstate/{compile → transform}/_loop_no_collection_test.py +2 -2
  86. brainstate/transform/_make_jaxpr.py +2016 -0
  87. brainstate/transform/_make_jaxpr_test.py +1510 -0
  88. brainstate/transform/_mapping.py +529 -0
  89. brainstate/transform/_mapping_test.py +194 -0
  90. brainstate/{compile → transform}/_progress_bar.py +78 -25
  91. brainstate/{augment → transform}/_random.py +65 -45
  92. brainstate/{compile → transform}/_unvmap.py +102 -5
  93. brainstate/transform/_util.py +286 -0
  94. brainstate/typing.py +594 -61
  95. brainstate/typing_test.py +780 -0
  96. brainstate/util/__init__.py +9 -32
  97. brainstate/util/_others.py +1025 -0
  98. brainstate/util/_others_test.py +962 -0
  99. brainstate/util/_pretty_pytree.py +1301 -0
  100. brainstate/util/_pretty_pytree_test.py +675 -0
  101. brainstate/util/{pretty_repr.py → _pretty_repr.py} +161 -27
  102. brainstate/util/_pretty_repr_test.py +696 -0
  103. brainstate/util/filter.py +557 -81
  104. brainstate/util/filter_test.py +912 -0
  105. brainstate/util/struct.py +769 -382
  106. brainstate/util/struct_test.py +602 -0
  107. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/METADATA +34 -17
  108. brainstate-0.2.0.dist-info/RECORD +111 -0
  109. brainstate/augment/__init__.py +0 -30
  110. brainstate/augment/_eval_shape.py +0 -99
  111. brainstate/augment/_mapping.py +0 -1060
  112. brainstate/augment/_mapping_test.py +0 -597
  113. brainstate/compile/__init__.py +0 -38
  114. brainstate/compile/_ad_checkpoint.py +0 -204
  115. brainstate/compile/_conditions.py +0 -256
  116. brainstate/compile/_make_jaxpr.py +0 -888
  117. brainstate/compile/_make_jaxpr_test.py +0 -156
  118. brainstate/compile/_util.py +0 -147
  119. brainstate/functional/__init__.py +0 -27
  120. brainstate/graph/_graph_node.py +0 -244
  121. brainstate/graph/_graph_node_test.py +0 -73
  122. brainstate/graph/_graph_operation_test.py +0 -563
  123. brainstate/init/__init__.py +0 -26
  124. brainstate/init/_base.py +0 -52
  125. brainstate/init/_generic.py +0 -244
  126. brainstate/init/_regular_inits.py +0 -105
  127. brainstate/init/_regular_inits_test.py +0 -50
  128. brainstate/nn/_inputs.py +0 -608
  129. brainstate/nn/_ltp.py +0 -28
  130. brainstate/nn/_neuron.py +0 -705
  131. brainstate/nn/_neuron_test.py +0 -161
  132. brainstate/nn/_others.py +0 -46
  133. brainstate/nn/_projection.py +0 -486
  134. brainstate/nn/_rate_rnns_test.py +0 -63
  135. brainstate/nn/_readout.py +0 -209
  136. brainstate/nn/_readout_test.py +0 -53
  137. brainstate/nn/_stp.py +0 -236
  138. brainstate/nn/_synapse.py +0 -505
  139. brainstate/nn/_synapse_test.py +0 -131
  140. brainstate/nn/_synaptic_projection.py +0 -423
  141. brainstate/nn/_synouts.py +0 -162
  142. brainstate/nn/_synouts_test.py +0 -57
  143. brainstate/nn/metrics.py +0 -388
  144. brainstate/optim/__init__.py +0 -38
  145. brainstate/optim/_base.py +0 -64
  146. brainstate/optim/_lr_scheduler.py +0 -448
  147. brainstate/optim/_lr_scheduler_test.py +0 -50
  148. brainstate/optim/_optax_optimizer.py +0 -152
  149. brainstate/optim/_optax_optimizer_test.py +0 -53
  150. brainstate/optim/_sgd_optimizer.py +0 -1104
  151. brainstate/random/_random_for_unit.py +0 -52
  152. brainstate/surrogate.py +0 -1957
  153. brainstate/transform.py +0 -23
  154. brainstate/util/caller.py +0 -98
  155. brainstate/util/others.py +0 -540
  156. brainstate/util/pretty_pytree.py +0 -945
  157. brainstate/util/pretty_pytree_test.py +0 -159
  158. brainstate/util/pretty_table.py +0 -2954
  159. brainstate/util/scaling.py +0 -258
  160. brainstate-0.1.10.dist-info/RECORD +0 -130
  161. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/WHEEL +0 -0
  162. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/licenses/LICENSE +0 -0
  163. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/top_level.txt +0 -0
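The bulk of this release is a package reorganization: the `compile` and `augment` subpackages merge into a single `transform` namespace, the `init` subpackage moves under `nn`, and several `nn` modules are renamed (e.g. `_fixedprob.py` → `_event_fixedprob.py`, `_rate_rnns.py` → `_rnns.py`). A minimal migration sketch, assuming the public symbols keep their names across the moves — the actual re-exports are not visible in this diff and should be checked against the 0.2.0 documentation:

```python
# Hypothetical import migration for brainstate 0.1.10 -> 0.2.0, inferred
# only from the file renames listed above; verify against the 0.2.0 docs.

# 0.1.10: JIT/loop transforms lived in `compile`, autodiff helpers in `augment`:
#   from brainstate.compile import jit          # brainstate/compile/_jit.py
#   from brainstate.augment import grad         # brainstate/augment/_autograd.py
# 0.2.0: both subpackages are folded into `transform`:
from brainstate.transform import jit, grad      # brainstate/transform/_jit.py, _autograd.py

# 0.1.10: initializers were a top-level subpackage:
#   from brainstate.init import KaimingNormal   # brainstate/init/_random_inits.py
# 0.2.0: initializers move under `nn`:
from brainstate.nn.init import KaimingNormal    # brainstate/nn/init.py
```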
brainstate/nn/_rate_rnns_test.py DELETED
@@ -1,63 +0,0 @@
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- #    http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
-
-
- import unittest
-
- import jax.numpy as jnp
-
- import brainstate
-
-
- class TestRateRNNModels(unittest.TestCase):
-     def setUp(self):
-         self.num_in = 3
-         self.num_out = 3
-         self.batch_size = 4
-         self.x = jnp.ones((self.batch_size, self.num_in))
-
-     def test_ValinaRNNCell(self):
-         model = brainstate.nn.ValinaRNNCell(num_in=self.num_in, num_out=self.num_out)
-         model.init_state(batch_size=self.batch_size)
-         output = model.update(self.x)
-         self.assertEqual(output.shape, (self.batch_size, self.num_out))
-
-     def test_GRUCell(self):
-         model = brainstate.nn.GRUCell(num_in=self.num_in, num_out=self.num_out)
-         model.init_state(batch_size=self.batch_size)
-         output = model.update(self.x)
-         self.assertEqual(output.shape, (self.batch_size, self.num_out))
-
-     def test_MGUCell(self):
-         model = brainstate.nn.MGUCell(num_in=self.num_in, num_out=self.num_out)
-         model.init_state(batch_size=self.batch_size)
-         output = model.update(self.x)
-         self.assertEqual(output.shape, (self.batch_size, self.num_out))
-
-     def test_LSTMCell(self):
-         model = brainstate.nn.LSTMCell(num_in=self.num_in, num_out=self.num_out)
-         model.init_state(batch_size=self.batch_size)
-         output = model.update(self.x)
-         self.assertEqual(output.shape, (self.batch_size, self.num_out))
-
-     def test_URLSTMCell(self):
-         model = brainstate.nn.URLSTMCell(num_in=self.num_in, num_out=self.num_out)
-         model.init_state(batch_size=self.batch_size)
-         output = model.update(self.x)
-         self.assertEqual(output.shape, (self.batch_size, self.num_out))
-
-
- if __name__ == '__main__':
-     unittest.main()
brainstate/nn/_readout.py DELETED
@@ -1,209 +0,0 @@
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- #    http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
-
- # -*- coding: utf-8 -*-
-
-
- import numbers
- from typing import Callable
-
- import brainunit as u
- import jax
-
- from brainstate import environ, init, surrogate
- from brainstate._state import HiddenState, ParamState
- from brainstate.typing import Size, ArrayLike
- from ._exp_euler import exp_euler_step
- from ._module import Module
- from ._neuron import Neuron
-
- __all__ = [
-     'LeakyRateReadout',
-     'LeakySpikeReadout',
- ]
-
-
- class LeakyRateReadout(Module):
-     r"""
-     Leaky dynamics for the read-out module.
-
-     This module implements a leaky integrator with the following dynamics:
-
-     .. math::
-         r_{t} = \alpha r_{t-1} + x_{t} W
-
-     where:
-     - :math:`r_{t}` is the output at time t
-     - :math:`\alpha = e^{-\Delta t / \tau}` is the decay factor
-     - :math:`x_{t}` is the input at time t
-     - :math:`W` is the weight matrix
-
-     The leaky integrator acts as a low-pass filter, allowing the network
-     to maintain memory of past inputs with an exponential decay determined
-     by the time constant tau.
-
-     Parameters
-     ----------
-     in_size : int or sequence of int
-         Size of the input dimension(s)
-     out_size : int or sequence of int
-         Size of the output dimension(s)
-     tau : ArrayLike, optional
-         Time constant of the leaky dynamics, by default 5ms
-     w_init : Callable, optional
-         Weight initialization function, by default KaimingNormal()
-     name : str, optional
-         Name of the module, by default None
-
-     Attributes
-     ----------
-     decay : float
-         Decay factor computed as exp(-dt/tau)
-     weight : ParamState
-         Weight matrix connecting input to output
-     r : HiddenState
-         Hidden state representing the output values
-     """
-     __module__ = 'brainstate.nn'
-
-     def __init__(
-         self,
-         in_size: Size,
-         out_size: Size,
-         tau: ArrayLike = 5. * u.ms,
-         w_init: Callable = init.KaimingNormal(),
-         name: str = None,
-     ):
-         super().__init__(name=name)
-
-         # parameters
-         self.in_size = (in_size,) if isinstance(in_size, numbers.Integral) else tuple(in_size)
-         self.out_size = (out_size,) if isinstance(out_size, numbers.Integral) else tuple(out_size)
-         self.tau = init.param(tau, self.in_size)
-         self.decay = u.math.exp(-environ.get_dt() / self.tau)
-
-         # weights
-         self.weight = ParamState(init.param(w_init, (self.in_size[0], self.out_size[0])))
-
-     def init_state(self, batch_size=None, **kwargs):
-         self.r = HiddenState(init.param(init.Constant(0.), self.out_size, batch_size))
-
-     def reset_state(self, batch_size=None, **kwargs):
-         self.r.value = init.param(init.Constant(0.), self.out_size, batch_size)
-
-     def update(self, x):
-         self.r.value = self.decay * self.r.value + x @ self.weight.value
-         return self.r.value
-
-
- class LeakySpikeReadout(Neuron):
-     r"""
-     Integrate-and-fire neuron model with leaky dynamics for readout functionality.
-
-     This class implements a spiking neuron with the following dynamics:
-
-     .. math::
-         \frac{dV}{dt} = \frac{-V + I_{in}}{\tau}
-
-     where:
-     - :math:`V` is the membrane potential
-     - :math:`\tau` is the membrane time constant
-     - :math:`I_{in}` is the input current
-
-     Spike generation occurs when :math:`V > V_{th}` according to:
-
-     .. math::
-         S_t = \text{surrogate}\left(\frac{V - V_{th}}{V_{th}}\right)
-
-     After spiking, the membrane potential is reset according to the reset mode:
-     - Soft reset: :math:`V \leftarrow V - V_{th} \cdot S_t`
-     - Hard reset: :math:`V \leftarrow V - V_t \cdot S_t` (where :math:`V_t` is detached)
-
-     Parameters
-     ----------
-     in_size : Size
-         Size of the input dimension
-     tau : ArrayLike, optional
-         Membrane time constant, by default 5ms
-     V_th : ArrayLike, optional
-         Spike threshold, by default 1mV
-     w_init : Callable, optional
-         Weight initialization function, by default KaimingNormal(unit=mV)
-     V_initializer : ArrayLike, optional
-         Initial membrane potential, by default ZeroInit(unit=mV)
-     spk_fun : Callable, optional
-         Surrogate gradient function for spike generation, by default ReluGrad()
-     spk_reset : str, optional
-         Reset mechanism after spike ('soft' or 'hard'), by default 'soft'
-     name : str, optional
-         Name of the module, by default None
-
-     Attributes
-     ----------
-     V : HiddenState
-         Membrane potential state variable
-     weight : ParamState
-         Synaptic weight matrix
-     """
-
-     __module__ = 'brainstate.nn'
-
-     def __init__(
-         self,
-         in_size: Size,
-         tau: ArrayLike = 5. * u.ms,
-         V_th: ArrayLike = 1. * u.mV,
-         w_init: Callable = init.KaimingNormal(unit=u.mV),
-         V_initializer: ArrayLike = init.ZeroInit(unit=u.mV),
-         spk_fun: Callable = surrogate.ReluGrad(),
-         spk_reset: str = 'soft',
-         name: str = None,
-     ):
-         super().__init__(in_size, name=name, spk_fun=spk_fun, spk_reset=spk_reset)
-
-         # parameters
-         self.tau = init.param(tau, self.varshape)
-         self.V_th = init.param(V_th, self.varshape)
-         self.V_initializer = V_initializer
-
-         # weights
-         self.weight = ParamState(init.param(w_init, (self.in_size[-1], self.out_size[-1])))
-
-     def init_state(self, batch_size, **kwargs):
-         self.V = HiddenState(init.param(self.V_initializer, self.varshape, batch_size))
-
-     def reset_state(self, batch_size, **kwargs):
-         self.V.value = init.param(self.V_initializer, self.varshape, batch_size)
-
-     @property
-     def spike(self):
-         return self.get_spike(self.V.value)
-
-     def get_spike(self, V):
-         v_scaled = (V - self.V_th) / self.V_th
-         return self.spk_fun(v_scaled)
-
-     def update(self, spk):
-         # reset
-         last_V = self.V.value
-         last_spike = self.get_spike(last_V)
-         V_th = self.V_th if self.spk_reset == 'soft' else jax.lax.stop_gradient(last_V)
-         V = last_V - V_th * last_spike
-         # membrane potential
-         x = spk @ self.weight.value
-         dv = lambda v: (-v + self.sum_current_inputs(x, v)) / self.tau
-         V = exp_euler_step(dv, V)
-         self.V.value = self.sum_delta_inputs(V)
-         return self.get_spike(V)
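At its core, the deleted `LeakyRateReadout` above is the recurrence `r_t = alpha * r_{t-1} + x_t @ W` with `alpha = exp(-dt / tau)`. A minimal standalone sketch of that recurrence in plain JAX, with illustrative `dt`/`tau` values and a Kaiming-style weight scale; this is not part of the 0.2.0 API:

```python
# Standalone sketch of the leaky read-out recurrence removed above:
#     r_t = alpha * r_{t-1} + x_t @ W,   alpha = exp(-dt / tau)
# Illustrative only; dt, tau, and the weight scale are free choices here.
import jax
import jax.numpy as jnp

def leaky_rate_readout(xs, W, dt=0.1, tau=5.0):
    """Scan the leaky integrator over a sequence xs of shape (T, num_in)."""
    alpha = jnp.exp(-dt / tau)          # per-step decay factor

    def step(r, x):
        r = alpha * r + x @ W           # low-pass filtered projection
        return r, r                     # carry and per-step output

    r0 = jnp.zeros(W.shape[1])          # plays the role of the HiddenState `r`
    _, rs = jax.lax.scan(step, r0, xs)
    return rs                           # shape (T, num_out)

key = jax.random.PRNGKey(0)
W = jax.random.normal(key, (3, 2)) * jnp.sqrt(2.0 / 3.0)   # Kaiming-style scale
xs = jnp.ones((10, 3))
print(leaky_rate_readout(xs, W).shape)  # (10, 2)
```

The scan carry holds only `r`, mirroring how the module keeps its state in a single `HiddenState`.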
brainstate/nn/_readout_test.py DELETED
@@ -1,53 +0,0 @@
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- #    http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
-
-
- import unittest
-
- import jax.numpy as jnp
-
- import brainstate
-
-
- class TestReadoutModels(unittest.TestCase):
-     def setUp(self):
-         self.in_size = 3
-         self.out_size = 3
-         self.batch_size = 4
-         self.tau = 5.0
-         self.V_th = 1.0
-         self.x = jnp.ones((self.batch_size, self.in_size))
-
-     def test_LeakyRateReadout(self):
-         with brainstate.environ.context(dt=0.1):
-             model = brainstate.nn.LeakyRateReadout(in_size=self.in_size, out_size=self.out_size, tau=self.tau)
-             model.init_state(batch_size=self.batch_size)
-             output = model.update(self.x)
-             self.assertEqual(output.shape, (self.batch_size, self.out_size))
-
-     def test_LeakySpikeReadout(self):
-         with brainstate.environ.context(dt=0.1):
-             model = brainstate.nn.LeakySpikeReadout(in_size=self.in_size, tau=self.tau, V_th=self.V_th,
-                                                     V_initializer=brainstate.init.ZeroInit(),
-                                                     w_init=brainstate.init.KaimingNormal())
-             model.init_state(batch_size=self.batch_size)
-             with brainstate.environ.context(t=0.):
-                 output = model.update(self.x)
-             self.assertEqual(output.shape, (self.batch_size, self.out_size))
-
-
- if __name__ == '__main__':
-     with brainstate.environ.context(dt=0.1):
-         unittest.main()
brainstate/nn/_stp.py DELETED
@@ -1,236 +0,0 @@
- # Copyright 2025 BDP Ecosystem Limited. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- #    http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
-
- # -*- coding: utf-8 -*-
-
- from typing import Optional
-
- import brainunit as u
-
- from brainstate import init
- from brainstate._state import HiddenState
- from brainstate.typing import ArrayLike, Size
- from ._exp_euler import exp_euler_step
- from ._synapse import Synapse
-
- __all__ = [
-     'ShortTermPlasticity', 'STP', 'STD',
- ]
-
-
- class ShortTermPlasticity(Synapse):
-     pass
-
-
- class STP(ShortTermPlasticity):
-     r"""
-     Synapse with short-term plasticity.
-
-     This class implements a synapse model with short-term plasticity (STP), which captures
-     activity-dependent changes in synaptic efficacy that occur over milliseconds to seconds.
-     The model simultaneously accounts for both short-term facilitation and depression
-     based on the formulation by Tsodyks & Markram (1998).
-
-     The model is characterized by the following equations:
-
-     $$
-     \frac{du}{dt} = -\frac{u}{\tau_f} + U \cdot (1 - u) \cdot \delta(t - t_{spike})
-     $$
-
-     $$
-     \frac{dx}{dt} = \frac{1 - x}{\tau_d} - u \cdot x \cdot \delta(t - t_{spike})
-     $$
-
-     $$
-     g_{syn} = u \cdot x
-     $$
-
-     where:
-     - $u$ represents the utilization of synaptic efficacy (facilitation variable)
-     - $x$ represents the available synaptic resources (depression variable)
-     - $\tau_f$ is the facilitation time constant
-     - $\tau_d$ is the depression time constant
-     - $U$ is the baseline utilization parameter
-     - $\delta(t - t_{spike})$ is the Dirac delta function representing presynaptic spikes
-     - $g_{syn}$ is the effective synaptic conductance
-
-     Parameters
-     ----------
-     in_size : Size
-         Size of the input.
-     name : str, optional
-         Name of the synapse instance.
-     U : ArrayLike, default=0.15
-         Baseline utilization parameter (fraction of resources used per action potential).
-     tau_f : ArrayLike, default=1500.*u.ms
-         Time constant of short-term facilitation in milliseconds.
-     tau_d : ArrayLike, default=200.*u.ms
-         Time constant of short-term depression (recovery of synaptic resources) in milliseconds.
-
-     Attributes
-     ----------
-     u : HiddenState
-         Utilization of synaptic efficacy (facilitation variable).
-     x : HiddenState
-         Available synaptic resources (depression variable).
-
-     Notes
-     -----
-     - Larger values of tau_f produce stronger facilitation effects.
-     - Larger values of tau_d lead to slower recovery from depression.
-     - The parameter U controls the initial release probability.
-     - The effective synaptic strength is the product of u and x.
-
-     References
-     ----------
-     .. [1] Tsodyks, M. V., & Markram, H. (1997). The neural code between neocortical
-            pyramidal neurons depends on neurotransmitter release probability.
-            Proceedings of the National Academy of Sciences, 94(2), 719-723.
-     .. [2] Tsodyks, M., Pawelzik, K., & Markram, H. (1998). Neural networks with dynamic
-            synapses. Neural computation, 10(4), 821-835.
-     """
-     __module__ = 'brainstate.nn'
-
-     def __init__(
-         self,
-         in_size: Size,
-         name: Optional[str] = None,
-         U: ArrayLike = 0.15,
-         tau_f: ArrayLike = 1500. * u.ms,
-         tau_d: ArrayLike = 200. * u.ms,
-     ):
-         super().__init__(name=name, in_size=in_size)
-
-         # parameters
-         self.tau_f = init.param(tau_f, self.varshape)
-         self.tau_d = init.param(tau_d, self.varshape)
-         self.U = init.param(U, self.varshape)
-
-     def init_state(self, batch_size: int = None, **kwargs):
-         self.x = HiddenState(init.param(init.Constant(1.), self.varshape, batch_size))
-         self.u = HiddenState(init.param(init.Constant(self.U), self.varshape, batch_size))
-
-     def reset_state(self, batch_size: int = None, **kwargs):
-         self.x.value = init.param(init.Constant(1.), self.varshape, batch_size)
-         self.u.value = init.param(init.Constant(self.U), self.varshape, batch_size)
-
-     def update(self, pre_spike):
-         u = exp_euler_step(lambda u: - u / self.tau_f, self.u.value)
-         x = exp_euler_step(lambda x: (1 - x) / self.tau_d, self.x.value)
-
-         # --- original code:
-         # if pre_spike.dtype == jax.numpy.bool_:
-         #     u = bm.where(pre_spike, u + self.U * (1 - self.u), u)
-         #     x = bm.where(pre_spike, x - u * self.x, x)
-         # else:
-         #     u = pre_spike * (u + self.U * (1 - self.u)) + (1 - pre_spike) * u
-         #     x = pre_spike * (x - u * self.x) + (1 - pre_spike) * x
-
-         # --- simplified code:
-         u = u + pre_spike * self.U * (1 - self.u.value)
-         x = x - pre_spike * u * self.x.value
-
-         self.u.value = u
-         self.x.value = x
-         return u * x * pre_spike
-
-
- class STD(ShortTermPlasticity):
-     r"""
-     Synapse with short-term depression.
-
-     This class implements a synapse model with short-term depression (STD), which captures
-     activity-dependent reduction in synaptic efficacy, typically caused by depletion of
-     neurotransmitter vesicles following repeated stimulation.
-
-     The model is characterized by the following equation:
-
-     $$
-     \frac{dx}{dt} = \frac{1 - x}{\tau} - U \cdot x \cdot \delta(t - t_{spike})
-     $$
-
-     $$
-     g_{syn} = x
-     $$
-
-     where:
-     - $x$ represents the available synaptic resources (depression variable)
-     - $\tau$ is the depression recovery time constant
-     - $U$ is the utilization parameter (fraction of resources depleted per spike)
-     - $\delta(t - t_{spike})$ is the Dirac delta function representing presynaptic spikes
-     - $g_{syn}$ is the effective synaptic conductance
-
-     Parameters
-     ----------
-     in_size : Size
-         Size of the input.
-     name : str, optional
-         Name of the synapse instance.
-     tau : ArrayLike, default=200.*u.ms
-         Time constant governing recovery of synaptic resources in milliseconds.
-     U : ArrayLike, default=0.07
-         Utilization parameter (fraction of resources used per action potential).
-
-     Attributes
-     ----------
-     x : HiddenState
-         Available synaptic resources (depression variable).
-
-     Notes
-     -----
-     - Larger values of tau lead to slower recovery from depression.
-     - Larger values of U cause stronger depression with each spike.
-     - This model is a simplified version of the STP model that only includes depression.
-
-     References
-     ----------
-     .. [1] Abbott, L. F., Varela, J. A., Sen, K., & Nelson, S. B. (1997). Synaptic
-            depression and cortical gain control. Science, 275(5297), 220-224.
-     .. [2] Tsodyks, M. V., & Markram, H. (1997). The neural code between neocortical
-            pyramidal neurons depends on neurotransmitter release probability.
-            Proceedings of the National Academy of Sciences, 94(2), 719-723.
-     """
-     __module__ = 'brainstate.nn'
-
-     def __init__(
-         self,
-         in_size: Size,
-         name: Optional[str] = None,
-         # synapse parameters
-         tau: ArrayLike = 200. * u.ms,
-         U: ArrayLike = 0.07,
-     ):
-         super().__init__(name=name, in_size=in_size)
-
-         # parameters
-         self.tau = init.param(tau, self.varshape)
-         self.U = init.param(U, self.varshape)
-
-     def init_state(self, batch_size: int = None, **kwargs):
-         self.x = HiddenState(init.param(init.Constant(1.), self.varshape, batch_size))
-
-     def reset_state(self, batch_size: int = None, **kwargs):
-         self.x.value = init.param(init.Constant(1.), self.varshape, batch_size)
-
-     def update(self, pre_spike):
-         x = exp_euler_step(lambda x: (1 - x) / self.tau, self.x.value)
-
-         # --- original code:
-         # self.x.value = bm.where(pre_spike, x - self.U * self.x, x)
-
-         # --- simplified code:
-         self.x.value = x - pre_spike * self.U * self.x.value
-
-         return self.x.value * pre_spike
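For reference, the Tsodyks-Markram step that the deleted `STP.update` computes can be written as a pure function. A minimal sketch, assuming scalar state, float spike inputs in {0, 1}, and illustrative `dt`/`U`/`tau` defaults; the exponential factors are exact for these linear ODEs, which is what `exp_euler_step` reduces to here:

```python
# Pure-function sketch of the STP update deleted above; not part of the
# brainstate 0.2.0 API. u: facilitation variable, x: available resources.
import jax.numpy as jnp

def stp_step(u, x, pre_spike, dt=0.1, U=0.15, tau_f=1500.0, tau_d=200.0):
    u_old, x_old = u, x
    # Decay between spikes (exponential Euler, exact for these linear ODEs):
    #   du/dt = -u / tau_f        -> u decays toward 0
    #   dx/dt = (1 - x) / tau_d   -> x relaxes toward 1
    u = u * jnp.exp(-dt / tau_f)
    x = 1.0 + (x - 1.0) * jnp.exp(-dt / tau_d)
    # Spike-triggered jumps (the delta terms), written against the pre-decay
    # values exactly as in the deleted update() above:
    u = u + pre_spike * U * (1.0 - u_old)
    x = x - pre_spike * u * x_old
    return u, x, u * x * pre_spike      # conductance g = u * x, gated by spikes

u, x = 0.15, 1.0                        # initial values used by init_state()
u, x, g = stp_step(u, x, pre_spike=1.0)
print(u, x, g)
```

Dropping the `u` variable and replacing its jump with a fixed `U` recovers the depression-only `STD` model.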