brainstate 0.1.10__py2.py3-none-any.whl → 0.2.1__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (163)
  1. brainstate/__init__.py +169 -58
  2. brainstate/_compatible_import.py +340 -148
  3. brainstate/_compatible_import_test.py +681 -0
  4. brainstate/_deprecation.py +210 -0
  5. brainstate/_deprecation_test.py +2319 -0
  6. brainstate/{util/error.py → _error.py} +45 -55
  7. brainstate/_state.py +1652 -1605
  8. brainstate/_state_test.py +52 -52
  9. brainstate/_utils.py +47 -47
  10. brainstate/environ.py +1495 -563
  11. brainstate/environ_test.py +1223 -62
  12. brainstate/graph/__init__.py +22 -29
  13. brainstate/graph/_node.py +240 -0
  14. brainstate/graph/_node_test.py +589 -0
  15. brainstate/graph/{_graph_operation.py → _operation.py} +1624 -1738
  16. brainstate/graph/_operation_test.py +1147 -0
  17. brainstate/mixin.py +1433 -365
  18. brainstate/mixin_test.py +1017 -77
  19. brainstate/nn/__init__.py +137 -135
  20. brainstate/nn/_activations.py +1100 -808
  21. brainstate/nn/_activations_test.py +354 -331
  22. brainstate/nn/_collective_ops.py +633 -514
  23. brainstate/nn/_collective_ops_test.py +774 -43
  24. brainstate/nn/_common.py +226 -178
  25. brainstate/nn/_common_test.py +154 -0
  26. brainstate/nn/_conv.py +2010 -501
  27. brainstate/nn/_conv_test.py +849 -238
  28. brainstate/nn/_delay.py +575 -588
  29. brainstate/nn/_delay_test.py +243 -238
  30. brainstate/nn/_dropout.py +618 -426
  31. brainstate/nn/_dropout_test.py +477 -100
  32. brainstate/nn/_dynamics.py +1267 -1343
  33. brainstate/nn/_dynamics_test.py +67 -78
  34. brainstate/nn/_elementwise.py +1298 -1119
  35. brainstate/nn/_elementwise_test.py +830 -169
  36. brainstate/nn/_embedding.py +408 -58
  37. brainstate/nn/_embedding_test.py +156 -0
  38. brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +233 -239
  39. brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +115 -114
  40. brainstate/nn/{_linear_mv.py → _event_linear.py} +83 -83
  41. brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +121 -120
  42. brainstate/nn/_exp_euler.py +254 -92
  43. brainstate/nn/_exp_euler_test.py +377 -35
  44. brainstate/nn/_linear.py +744 -424
  45. brainstate/nn/_linear_test.py +475 -107
  46. brainstate/nn/_metrics.py +1070 -0
  47. brainstate/nn/_metrics_test.py +611 -0
  48. brainstate/nn/_module.py +384 -377
  49. brainstate/nn/_module_test.py +40 -40
  50. brainstate/nn/_normalizations.py +1334 -975
  51. brainstate/nn/_normalizations_test.py +699 -73
  52. brainstate/nn/_paddings.py +1020 -0
  53. brainstate/nn/_paddings_test.py +723 -0
  54. brainstate/nn/_poolings.py +2239 -1177
  55. brainstate/nn/_poolings_test.py +953 -217
  56. brainstate/nn/{_rate_rnns.py → _rnns.py} +946 -554
  57. brainstate/nn/_rnns_test.py +593 -0
  58. brainstate/nn/_utils.py +216 -89
  59. brainstate/nn/_utils_test.py +402 -0
  60. brainstate/{init/_random_inits.py → nn/init.py} +809 -553
  61. brainstate/{init/_random_inits_test.py → nn/init_test.py} +180 -149
  62. brainstate/random/__init__.py +270 -24
  63. brainstate/random/_rand_funs.py +3938 -3616
  64. brainstate/random/_rand_funs_test.py +640 -567
  65. brainstate/random/_rand_seed.py +675 -210
  66. brainstate/random/_rand_seed_test.py +48 -48
  67. brainstate/random/_rand_state.py +1617 -1409
  68. brainstate/random/_rand_state_test.py +551 -0
  69. brainstate/transform/__init__.py +59 -0
  70. brainstate/transform/_ad_checkpoint.py +176 -0
  71. brainstate/{compile → transform}/_ad_checkpoint_test.py +49 -49
  72. brainstate/{augment → transform}/_autograd.py +1025 -778
  73. brainstate/{augment → transform}/_autograd_test.py +1289 -1289
  74. brainstate/transform/_conditions.py +316 -0
  75. brainstate/{compile → transform}/_conditions_test.py +220 -220
  76. brainstate/{compile → transform}/_error_if.py +94 -92
  77. brainstate/{compile → transform}/_error_if_test.py +52 -52
  78. brainstate/transform/_eval_shape.py +145 -0
  79. brainstate/{augment → transform}/_eval_shape_test.py +38 -38
  80. brainstate/{compile → transform}/_jit.py +399 -346
  81. brainstate/{compile → transform}/_jit_test.py +143 -143
  82. brainstate/{compile → transform}/_loop_collect_return.py +675 -536
  83. brainstate/{compile → transform}/_loop_collect_return_test.py +58 -58
  84. brainstate/{compile → transform}/_loop_no_collection.py +283 -184
  85. brainstate/{compile → transform}/_loop_no_collection_test.py +50 -50
  86. brainstate/transform/_make_jaxpr.py +2016 -0
  87. brainstate/transform/_make_jaxpr_test.py +1510 -0
  88. brainstate/transform/_mapping.py +529 -0
  89. brainstate/transform/_mapping_test.py +194 -0
  90. brainstate/{compile → transform}/_progress_bar.py +255 -202
  91. brainstate/{augment → transform}/_random.py +171 -151
  92. brainstate/{compile → transform}/_unvmap.py +256 -159
  93. brainstate/transform/_util.py +286 -0
  94. brainstate/typing.py +837 -304
  95. brainstate/typing_test.py +780 -0
  96. brainstate/util/__init__.py +27 -50
  97. brainstate/util/_others.py +1025 -0
  98. brainstate/util/_others_test.py +962 -0
  99. brainstate/util/_pretty_pytree.py +1301 -0
  100. brainstate/util/_pretty_pytree_test.py +675 -0
  101. brainstate/util/{pretty_repr.py → _pretty_repr.py} +462 -328
  102. brainstate/util/_pretty_repr_test.py +696 -0
  103. brainstate/util/filter.py +945 -469
  104. brainstate/util/filter_test.py +912 -0
  105. brainstate/util/struct.py +910 -523
  106. brainstate/util/struct_test.py +602 -0
  107. {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/METADATA +108 -91
  108. brainstate-0.2.1.dist-info/RECORD +111 -0
  109. {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/licenses/LICENSE +202 -202
  110. brainstate/augment/__init__.py +0 -30
  111. brainstate/augment/_eval_shape.py +0 -99
  112. brainstate/augment/_mapping.py +0 -1060
  113. brainstate/augment/_mapping_test.py +0 -597
  114. brainstate/compile/__init__.py +0 -38
  115. brainstate/compile/_ad_checkpoint.py +0 -204
  116. brainstate/compile/_conditions.py +0 -256
  117. brainstate/compile/_make_jaxpr.py +0 -888
  118. brainstate/compile/_make_jaxpr_test.py +0 -156
  119. brainstate/compile/_util.py +0 -147
  120. brainstate/functional/__init__.py +0 -27
  121. brainstate/graph/_graph_node.py +0 -244
  122. brainstate/graph/_graph_node_test.py +0 -73
  123. brainstate/graph/_graph_operation_test.py +0 -563
  124. brainstate/init/__init__.py +0 -26
  125. brainstate/init/_base.py +0 -52
  126. brainstate/init/_generic.py +0 -244
  127. brainstate/init/_regular_inits.py +0 -105
  128. brainstate/init/_regular_inits_test.py +0 -50
  129. brainstate/nn/_inputs.py +0 -608
  130. brainstate/nn/_ltp.py +0 -28
  131. brainstate/nn/_neuron.py +0 -705
  132. brainstate/nn/_neuron_test.py +0 -161
  133. brainstate/nn/_others.py +0 -46
  134. brainstate/nn/_projection.py +0 -486
  135. brainstate/nn/_rate_rnns_test.py +0 -63
  136. brainstate/nn/_readout.py +0 -209
  137. brainstate/nn/_readout_test.py +0 -53
  138. brainstate/nn/_stp.py +0 -236
  139. brainstate/nn/_synapse.py +0 -505
  140. brainstate/nn/_synapse_test.py +0 -131
  141. brainstate/nn/_synaptic_projection.py +0 -423
  142. brainstate/nn/_synouts.py +0 -162
  143. brainstate/nn/_synouts_test.py +0 -57
  144. brainstate/nn/metrics.py +0 -388
  145. brainstate/optim/__init__.py +0 -38
  146. brainstate/optim/_base.py +0 -64
  147. brainstate/optim/_lr_scheduler.py +0 -448
  148. brainstate/optim/_lr_scheduler_test.py +0 -50
  149. brainstate/optim/_optax_optimizer.py +0 -152
  150. brainstate/optim/_optax_optimizer_test.py +0 -53
  151. brainstate/optim/_sgd_optimizer.py +0 -1104
  152. brainstate/random/_random_for_unit.py +0 -52
  153. brainstate/surrogate.py +0 -1957
  154. brainstate/transform.py +0 -23
  155. brainstate/util/caller.py +0 -98
  156. brainstate/util/others.py +0 -540
  157. brainstate/util/pretty_pytree.py +0 -945
  158. brainstate/util/pretty_pytree_test.py +0 -159
  159. brainstate/util/pretty_table.py +0 -2954
  160. brainstate/util/scaling.py +0 -258
  161. brainstate-0.1.10.dist-info/RECORD +0 -130
  162. {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/WHEEL +0 -0
  163. {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/top_level.txt +0 -0
brainstate/nn/_rate_rnns_test.py DELETED
@@ -1,63 +0,0 @@
1
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
-
17
- import unittest
18
-
19
- import jax.numpy as jnp
20
-
21
- import brainstate
22
-
23
-
24
class TestRateRNNModels(unittest.TestCase):
    """Shape smoke tests for the rate-based recurrent cells in ``brainstate.nn``.

    Every test builds one cell with identical input/output widths, allocates a
    batched state, runs a single ``update`` step on an all-ones input and checks
    that the output has shape ``(batch_size, num_out)``.
    """

    def setUp(self):
        # Shared, square configuration for all cells under test.
        self.num_in = 3
        self.num_out = 3
        self.batch_size = 4
        self.x = jnp.ones((self.batch_size, self.num_in))

    def _assert_one_step_shape(self, cell_cls):
        # Construct -> init batched state -> single update -> shape check.
        cell = cell_cls(num_in=self.num_in, num_out=self.num_out)
        cell.init_state(batch_size=self.batch_size)
        out = cell.update(self.x)
        self.assertEqual(out.shape, (self.batch_size, self.num_out))

    def test_ValinaRNNCell(self):
        self._assert_one_step_shape(brainstate.nn.ValinaRNNCell)

    def test_GRUCell(self):
        self._assert_one_step_shape(brainstate.nn.GRUCell)

    def test_MGUCell(self):
        self._assert_one_step_shape(brainstate.nn.MGUCell)

    def test_LSTMCell(self):
        self._assert_one_step_shape(brainstate.nn.LSTMCell)

    def test_URLSTMCell(self):
        self._assert_one_step_shape(brainstate.nn.URLSTMCell)


if __name__ == '__main__':
    unittest.main()
brainstate/nn/_readout.py DELETED
@@ -1,209 +0,0 @@
1
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
- # -*- coding: utf-8 -*-
17
-
18
-
19
- import numbers
20
- from typing import Callable
21
-
22
- import brainunit as u
23
- import jax
24
-
25
- from brainstate import environ, init, surrogate
26
- from brainstate._state import HiddenState, ParamState
27
- from brainstate.typing import Size, ArrayLike
28
- from ._exp_euler import exp_euler_step
29
- from ._module import Module
30
- from ._neuron import Neuron
31
-
32
- __all__ = [
33
- 'LeakyRateReadout',
34
- 'LeakySpikeReadout',
35
- ]
36
-
37
-
38
class LeakyRateReadout(Module):
    r"""
    Leaky dynamics for the read-out module.

    This module implements a leaky integrator with the following dynamics:

    .. math::
        r_{t} = \alpha r_{t-1} + x_{t} W

    where:
    - :math:`r_{t}` is the output at time t
    - :math:`\alpha = e^{-\Delta t / \tau}` is the decay factor
    - :math:`x_{t}` is the input at time t
    - :math:`W` is the weight matrix

    The leaky integrator acts as a low-pass filter, allowing the network
    to maintain memory of past inputs with an exponential decay determined
    by the time constant tau.

    Parameters
    ----------
    in_size : int or sequence of int
        Size of the input dimension(s). The last entry must match the last
        axis of the input ``x``.
    out_size : int or sequence of int
        Size of the output dimension(s).
    tau : ArrayLike, optional
        Time constant of the leaky dynamics, by default 5ms. It is broadcast
        to ``out_size`` because the resulting decay multiplies the output
        state ``r``.
    w_init : Callable, optional
        Weight initialization function, by default KaimingNormal()
    name : str, optional
        Name of the module, by default None

    Attributes
    ----------
    decay : ArrayLike
        Decay factor computed as exp(-dt/tau), shaped like ``out_size``.
    weight : ParamState
        Weight matrix of shape ``(in_size[-1], out_size[-1])`` connecting input to output.
    r : HiddenState
        Hidden state representing the output values (created by ``init_state``).
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        in_size: Size,
        out_size: Size,
        tau: ArrayLike = 5. * u.ms,
        w_init: Callable = init.KaimingNormal(),
        name: str = None,
    ):
        super().__init__(name=name)

        # parameters
        self.in_size = (in_size,) if isinstance(in_size, numbers.Integral) else tuple(in_size)
        self.out_size = (out_size,) if isinstance(out_size, numbers.Integral) else tuple(out_size)
        # ``decay`` scales the hidden state ``r`` (shape ``out_size``), so tau must be
        # shaped like the output; shaping it like the input breaks when in/out differ.
        self.tau = init.param(tau, self.out_size)
        self.decay = u.math.exp(-environ.get_dt() / self.tau)

        # weights: ``x @ W`` contracts the last input axis and produces the last output axis
        self.weight = ParamState(init.param(w_init, (self.in_size[-1], self.out_size[-1])))

    def init_state(self, batch_size=None, **kwargs):
        """Create the output state ``r`` filled with zeros (optionally batched)."""
        self.r = HiddenState(init.param(init.Constant(0.), self.out_size, batch_size))

    def reset_state(self, batch_size=None, **kwargs):
        """Reset the output state ``r`` to zeros (optionally batched)."""
        self.r.value = init.param(init.Constant(0.), self.out_size, batch_size)

    def update(self, x):
        """Advance one step: ``r <- decay * r + x @ W``; return the new ``r``."""
        self.r.value = self.decay * self.r.value + x @ self.weight.value
        return self.r.value
109
-
110
-
111
class LeakySpikeReadout(Neuron):
    r"""
    Integrate-and-fire neuron model with leaky dynamics for readout functionality.

    This class implements a spiking neuron with the following dynamics:

    .. math::
        \frac{dV}{dt} = \frac{-V + I_{in}}{\tau}

    where:
    - :math:`V` is the membrane potential
    - :math:`\tau` is the membrane time constant
    - :math:`I_{in}` is the input current: the projected input ``spk @ W``
      combined with any current inputs collected by ``sum_current_inputs``

    Spike generation occurs when :math:`V > V_{th}` according to:

    .. math::
        S_t = \text{surrogate}\left(\frac{V - V_{th}}{V_{th}}\right)

    At the start of every step the spike produced by the previous membrane
    potential resets V according to the reset mode:
    - Soft reset: :math:`V \leftarrow V - V_{th} \cdot S_t`
    - Hard reset: :math:`V \leftarrow V - V \cdot S_t`, where the subtracted
      :math:`V` is detached from the gradient (a unit spike drives V to zero).
      Any ``spk_reset`` value other than ``'soft'`` takes this branch.

    Parameters
    ----------
    in_size : Size
        Size of the input dimension
    tau : ArrayLike, optional
        Membrane time constant, by default 5ms
    V_th : ArrayLike, optional
        Spike threshold, by default 1mV
    w_init : Callable, optional
        Weight initialization function, by default KaimingNormal(unit=mV)
    V_initializer : ArrayLike or Callable, optional
        Initial membrane potential, by default ZeroInit(unit=mV)
    spk_fun : Callable, optional
        Surrogate gradient function for spike generation, by default ReluGrad()
    spk_reset : str, optional
        Reset mechanism after spike ('soft' or 'hard'), by default 'soft'
    name : str, optional
        Name of the module, by default None

    Attributes
    ----------
    V : HiddenState
        Membrane potential state variable (created by ``init_state``)
    weight : ParamState
        Synaptic weight matrix of shape ``(in_size[-1], out_size[-1])``
    """

    __module__ = 'brainstate.nn'

    def __init__(
        self,
        in_size: Size,
        tau: ArrayLike = 5. * u.ms,
        V_th: ArrayLike = 1. * u.mV,
        w_init: Callable = init.KaimingNormal(unit=u.mV),
        V_initializer: ArrayLike = init.ZeroInit(unit=u.mV),
        spk_fun: Callable = surrogate.ReluGrad(),
        spk_reset: str = 'soft',
        name: str = None,
    ):
        super().__init__(in_size, name=name, spk_fun=spk_fun, spk_reset=spk_reset)

        # parameters
        self.tau = init.param(tau, self.varshape)
        self.V_th = init.param(V_th, self.varshape)
        self.V_initializer = V_initializer

        # weights
        self.weight = ParamState(init.param(w_init, (self.in_size[-1], self.out_size[-1])))

    def init_state(self, batch_size=None, **kwargs):
        """Create the membrane potential ``V`` from ``V_initializer``.

        ``batch_size`` is forwarded to ``init.param``; the ``None`` default is
        presumably "no batch axis" (same convention as the other modules).
        """
        self.V = HiddenState(init.param(self.V_initializer, self.varshape, batch_size))

    def reset_state(self, batch_size=None, **kwargs):
        """Re-initialize ``V`` from ``V_initializer``."""
        self.V.value = init.param(self.V_initializer, self.varshape, batch_size)

    @property
    def spike(self):
        """Spike output of the currently stored membrane potential."""
        return self.get_spike(self.V.value)

    def get_spike(self, V):
        """Apply the surrogate spike function to the threshold-normalized potential."""
        v_scaled = (V - self.V_th) / self.V_th
        return self.spk_fun(v_scaled)

    def update(self, spk):
        """Advance one step with input ``spk`` and return the emitted spikes."""
        # --- reset using the spike emitted by the previous state ---
        last_V = self.V.value
        last_spike = self.get_spike(last_V)
        V_th = self.V_th if self.spk_reset == 'soft' else jax.lax.stop_gradient(last_V)
        V = last_V - V_th * last_spike

        # --- membrane potential integration ---
        x = spk @ self.weight.value

        def dv(v):
            return (-v + self.sum_current_inputs(x, v)) / self.tau

        V = exp_euler_step(dv, V)

        # Apply delta inputs *before* emitting the spike, so the returned spike
        # agrees with the stored state (and with the ``spike`` property).
        V = self.sum_delta_inputs(V)
        self.V.value = V
        return self.get_spike(V)
brainstate/nn/_readout_test.py DELETED
@@ -1,53 +0,0 @@
1
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
-
17
- import unittest
18
-
19
- import jax.numpy as jnp
20
-
21
- import brainstate
22
-
23
-
24
class TestReadoutModels(unittest.TestCase):
    """Shape smoke tests for the leaky rate and spiking read-out layers.

    Both models are built and stepped inside an ``environ`` context with
    ``dt=0.1`` and must return an output of shape ``(batch_size, out_size)``.
    """

    def setUp(self):
        self.in_size = 3
        self.out_size = 3
        self.batch_size = 4
        self.tau = 5.0
        self.V_th = 1.0
        self.x = jnp.ones((self.batch_size, self.in_size))

    def test_LeakyRateReadout(self):
        expected_shape = (self.batch_size, self.out_size)
        with brainstate.environ.context(dt=0.1):
            readout = brainstate.nn.LeakyRateReadout(
                in_size=self.in_size,
                out_size=self.out_size,
                tau=self.tau,
            )
            readout.init_state(batch_size=self.batch_size)
            rates = readout.update(self.x)
            self.assertEqual(rates.shape, expected_shape)

    def test_LeakySpikeReadout(self):
        with brainstate.environ.context(dt=0.1):
            readout = brainstate.nn.LeakySpikeReadout(
                in_size=self.in_size,
                tau=self.tau,
                V_th=self.V_th,
                V_initializer=brainstate.init.ZeroInit(),
                w_init=brainstate.init.KaimingNormal(),
            )
            readout.init_state(batch_size=self.batch_size)
            # The step additionally needs a current time ``t`` in the environment.
            with brainstate.environ.context(t=0.):
                spikes = readout.update(self.x)
            self.assertEqual(spikes.shape, (self.batch_size, self.out_size))


if __name__ == '__main__':
    with brainstate.environ.context(dt=0.1):
        unittest.main()
brainstate/nn/_stp.py DELETED
@@ -1,236 +0,0 @@
1
- # Copyright 2025 BDP Ecosystem Limited. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
- # -*- coding: utf-8 -*-
17
-
18
- from typing import Optional
19
-
20
- import brainunit as u
21
-
22
- from brainstate import init
23
- from brainstate._state import HiddenState
24
- from brainstate.typing import ArrayLike, Size
25
- from ._exp_euler import exp_euler_step
26
- from ._synapse import Synapse
27
-
28
- __all__ = [
29
- 'ShortTermPlasticity', 'STP', 'STD',
30
- ]
31
-
32
-
33
class ShortTermPlasticity(Synapse):
    """
    Common base class for short-term plasticity synapse models.

    It adds no behavior on top of :class:`Synapse`; the concrete models
    ``STP`` and ``STD`` derive from it so they share one base type.
    """
35
-
36
-
37
class STP(ShortTermPlasticity):
    r"""
    Synapse with short-term plasticity.

    This class implements a synapse model with short-term plasticity (STP), which captures
    activity-dependent changes in synaptic efficacy that occur over milliseconds to seconds.
    The model simultaneously accounts for both short-term facilitation and depression,
    following Tsodyks & Markram (1997) [1]_ and Tsodyks, Pawelzik & Markram (1998) [2]_.

    The model is characterized by the following equations:

    .. math::

        \frac{du}{dt} = -\frac{u}{\tau_f} + U \cdot (1 - u) \cdot \delta(t - t_{spike})

    .. math::

        \frac{dx}{dt} = \frac{1 - x}{\tau_d} - u \cdot x \cdot \delta(t - t_{spike})

    .. math::

        g_{syn} = u \cdot x

    where:
    - :math:`u` represents the utilization of synaptic efficacy (facilitation variable)
    - :math:`x` represents the available synaptic resources (depression variable)
    - :math:`\tau_f` is the facilitation time constant
    - :math:`\tau_d` is the depression time constant
    - :math:`U` is the baseline utilization parameter
    - :math:`\delta(t - t_{spike})` is the Dirac delta function representing presynaptic spikes
    - :math:`g_{syn}` is the effective synaptic conductance

    Parameters
    ----------
    in_size : Size
        Size of the input.
    name : str, optional
        Name of the synapse instance.
    U : ArrayLike, default=0.15
        Baseline utilization parameter (fraction of resources used per action potential).
    tau_f : ArrayLike, default=1500.*u.ms
        Time constant of short-term facilitation (a time quantity).
    tau_d : ArrayLike, default=200.*u.ms
        Time constant of short-term depression, i.e. recovery of synaptic
        resources (a time quantity).

    Attributes
    ----------
    u : HiddenState
        Utilization of synaptic efficacy (facilitation variable), initialized to ``U``.
    x : HiddenState
        Available synaptic resources (depression variable), initialized to 1.

    Notes
    -----
    - Larger values of tau_f produce stronger facilitation effects.
    - Larger values of tau_d lead to slower recovery from depression.
    - The parameter U controls the initial release probability.
    - ``update`` returns ``u * x * pre_spike``: the efficacy gated by the input spikes.

    References
    ----------
    .. [1] Tsodyks, M. V., & Markram, H. (1997). The neural code between neocortical
           pyramidal neurons depends on neurotransmitter release probability.
           Proceedings of the National Academy of Sciences, 94(2), 719-723.
    .. [2] Tsodyks, M., Pawelzik, K., & Markram, H. (1998). Neural networks with dynamic
           synapses. Neural computation, 10(4), 821-835.
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        in_size: Size,
        name: Optional[str] = None,
        U: ArrayLike = 0.15,
        tau_f: ArrayLike = 1500. * u.ms,
        tau_d: ArrayLike = 200. * u.ms,
    ):
        super().__init__(name=name, in_size=in_size)

        # parameters
        self.tau_f = init.param(tau_f, self.varshape)
        self.tau_d = init.param(tau_d, self.varshape)
        self.U = init.param(U, self.varshape)

    def init_state(self, batch_size: int = None, **kwargs):
        """Create ``x`` (filled with 1) and ``u`` (filled with ``U``)."""
        self.x = HiddenState(init.param(init.Constant(1.), self.varshape, batch_size))
        self.u = HiddenState(init.param(init.Constant(self.U), self.varshape, batch_size))

    def reset_state(self, batch_size: int = None, **kwargs):
        """Reset ``x`` to 1 and ``u`` to ``U``."""
        self.x.value = init.param(init.Constant(1.), self.varshape, batch_size)
        self.u.value = init.param(init.Constant(self.U), self.varshape, batch_size)

    def update(self, pre_spike):
        """Advance one step given presynaptic spikes; return ``u * x * pre_spike``."""
        # Continuous part: u relaxes toward 0, x recovers toward 1.
        # (Locals are not named ``u`` so they do not shadow the ``brainunit`` alias.)
        u_next = exp_euler_step(lambda v: - v / self.tau_f, self.u.value)
        x_next = exp_euler_step(lambda v: (1 - v) / self.tau_d, self.x.value)

        # Spike-triggered jumps. Multiplying by ``pre_spike`` instead of a ``where``
        # handles both boolean and graded spike inputs with one expression.
        # The facilitation jump uses the stored (pre-decay) ``u``; the depression
        # jump uses the freshly facilitated ``u_next`` and the stored ``x``.
        u_next = u_next + pre_spike * self.U * (1 - self.u.value)
        x_next = x_next - pre_spike * u_next * self.x.value

        self.u.value = u_next
        self.x.value = x_next
        return u_next * x_next * pre_spike
148
-
149
-
150
class STD(ShortTermPlasticity):
    r"""
    Synapse with short-term depression.

    This class implements a synapse model with short-term depression (STD), which captures
    activity-dependent reduction in synaptic efficacy, typically caused by depletion of
    neurotransmitter vesicles following repeated stimulation.

    The model is characterized by the following equations:

    .. math::

        \frac{dx}{dt} = \frac{1 - x}{\tau} - U \cdot x \cdot \delta(t - t_{spike})

    .. math::

        g_{syn} = x

    where:
    - :math:`x` represents the available synaptic resources (depression variable)
    - :math:`\tau` is the depression recovery time constant
    - :math:`U` is the utilization parameter (fraction of resources depleted per spike)
    - :math:`\delta(t - t_{spike})` is the Dirac delta function representing presynaptic spikes
    - :math:`g_{syn}` is the effective synaptic conductance

    Parameters
    ----------
    in_size : Size
        Size of the input.
    name : str, optional
        Name of the synapse instance.
    tau : ArrayLike, default=200.*u.ms
        Time constant governing recovery of synaptic resources (a time quantity).
    U : ArrayLike, default=0.07
        Utilization parameter (fraction of resources used per action potential).

    Attributes
    ----------
    x : HiddenState
        Available synaptic resources (depression variable), initialized to 1.

    Notes
    -----
    - Larger values of tau lead to slower recovery from depression.
    - Larger values of U cause stronger depression with each spike.
    - This model is a simplified version of the STP model that only includes depression.
    - ``update`` returns ``x * pre_spike`` using the post-update ``x``.

    References
    ----------
    .. [1] Abbott, L. F., Varela, J. A., Sen, K., & Nelson, S. B. (1997). Synaptic
           depression and cortical gain control. Science, 275(5297), 220-224.
    .. [2] Tsodyks, M. V., & Markram, H. (1997). The neural code between neocortical
           pyramidal neurons depends on neurotransmitter release probability.
           Proceedings of the National Academy of Sciences, 94(2), 719-723.
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        in_size: Size,
        name: Optional[str] = None,
        # synapse parameters
        tau: ArrayLike = 200. * u.ms,
        U: ArrayLike = 0.07,
    ):
        super().__init__(name=name, in_size=in_size)

        # parameters
        self.tau = init.param(tau, self.varshape)
        self.U = init.param(U, self.varshape)

    def init_state(self, batch_size: int = None, **kwargs):
        """Create the resource state ``x`` filled with 1."""
        self.x = HiddenState(init.param(init.Constant(1.), self.varshape, batch_size))

    def reset_state(self, batch_size: int = None, **kwargs):
        """Reset the resource state ``x`` to 1."""
        self.x.value = init.param(init.Constant(1.), self.varshape, batch_size)

    def update(self, pre_spike):
        """Advance one step given presynaptic spikes; return ``x * pre_spike``."""
        # Continuous part: resources recover toward 1.
        x_next = exp_euler_step(lambda v: (1 - v) / self.tau, self.x.value)

        # Spike-triggered depletion, proportional to the stored (pre-recovery) x.
        # Multiplying by ``pre_spike`` handles boolean and graded spikes alike.
        self.x.value = x_next - pre_spike * self.U * self.x.value

        return self.x.value * pre_spike