brainstate 0.0.2.post20241010__py2.py3-none-any.whl → 0.1.0.post20241122__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (184) hide show
  1. benchmark/COBA_2005.py +125 -0
  2. benchmark/CUBA_2005.py +149 -0
  3. brainstate/__init__.py +31 -11
  4. brainstate/_state.py +760 -316
  5. brainstate/_state_test.py +41 -12
  6. brainstate/_utils.py +31 -4
  7. brainstate/augment/__init__.py +40 -0
  8. brainstate/augment/_autograd.py +611 -0
  9. brainstate/augment/_autograd_test.py +1193 -0
  10. brainstate/augment/_eval_shape.py +102 -0
  11. brainstate/augment/_eval_shape_test.py +40 -0
  12. brainstate/augment/_mapping.py +525 -0
  13. brainstate/augment/_mapping_test.py +210 -0
  14. brainstate/augment/_random.py +99 -0
  15. brainstate/{transform → compile}/__init__.py +25 -13
  16. brainstate/compile/_ad_checkpoint.py +204 -0
  17. brainstate/compile/_ad_checkpoint_test.py +51 -0
  18. brainstate/compile/_conditions.py +259 -0
  19. brainstate/compile/_conditions_test.py +221 -0
  20. brainstate/compile/_error_if.py +94 -0
  21. brainstate/compile/_error_if_test.py +54 -0
  22. brainstate/compile/_jit.py +314 -0
  23. brainstate/compile/_jit_test.py +143 -0
  24. brainstate/compile/_loop_collect_return.py +516 -0
  25. brainstate/compile/_loop_collect_return_test.py +59 -0
  26. brainstate/compile/_loop_no_collection.py +185 -0
  27. brainstate/compile/_loop_no_collection_test.py +51 -0
  28. brainstate/compile/_make_jaxpr.py +756 -0
  29. brainstate/compile/_make_jaxpr_test.py +134 -0
  30. brainstate/compile/_progress_bar.py +111 -0
  31. brainstate/compile/_unvmap.py +159 -0
  32. brainstate/compile/_util.py +147 -0
  33. brainstate/environ.py +408 -381
  34. brainstate/environ_test.py +34 -32
  35. brainstate/event/__init__.py +27 -0
  36. brainstate/event/_csr.py +316 -0
  37. brainstate/event/_csr_benchmark.py +14 -0
  38. brainstate/event/_csr_test.py +118 -0
  39. brainstate/event/_fixed_probability.py +708 -0
  40. brainstate/event/_fixed_probability_benchmark.py +128 -0
  41. brainstate/event/_fixed_probability_test.py +131 -0
  42. brainstate/event/_linear.py +359 -0
  43. brainstate/event/_linear_benckmark.py +82 -0
  44. brainstate/event/_linear_test.py +117 -0
  45. brainstate/{nn/event → event}/_misc.py +7 -7
  46. brainstate/event/_xla_custom_op.py +312 -0
  47. brainstate/event/_xla_custom_op_test.py +55 -0
  48. brainstate/functional/_activations.py +521 -511
  49. brainstate/functional/_activations_test.py +300 -300
  50. brainstate/functional/_normalization.py +43 -43
  51. brainstate/functional/_others.py +15 -15
  52. brainstate/functional/_spikes.py +49 -49
  53. brainstate/graph/__init__.py +33 -0
  54. brainstate/graph/_graph_context.py +443 -0
  55. brainstate/graph/_graph_context_test.py +65 -0
  56. brainstate/graph/_graph_convert.py +246 -0
  57. brainstate/graph/_graph_node.py +300 -0
  58. brainstate/graph/_graph_node_test.py +75 -0
  59. brainstate/graph/_graph_operation.py +1746 -0
  60. brainstate/graph/_graph_operation_test.py +724 -0
  61. brainstate/init/_base.py +28 -10
  62. brainstate/init/_generic.py +175 -172
  63. brainstate/init/_random_inits.py +470 -415
  64. brainstate/init/_random_inits_test.py +150 -0
  65. brainstate/init/_regular_inits.py +66 -69
  66. brainstate/init/_regular_inits_test.py +51 -0
  67. brainstate/mixin.py +236 -244
  68. brainstate/mixin_test.py +44 -46
  69. brainstate/nn/__init__.py +26 -51
  70. brainstate/nn/_collective_ops.py +199 -0
  71. brainstate/nn/_dyn_impl/__init__.py +46 -0
  72. brainstate/nn/_dyn_impl/_dynamics_neuron.py +290 -0
  73. brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +162 -0
  74. brainstate/nn/_dyn_impl/_dynamics_synapse.py +315 -0
  75. brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +132 -0
  76. brainstate/nn/_dyn_impl/_inputs.py +154 -0
  77. brainstate/nn/{event/__init__.py → _dyn_impl/_projection_alignpost.py} +8 -8
  78. brainstate/nn/_dyn_impl/_rate_rnns.py +400 -0
  79. brainstate/nn/_dyn_impl/_rate_rnns_test.py +64 -0
  80. brainstate/nn/_dyn_impl/_readout.py +128 -0
  81. brainstate/nn/_dyn_impl/_readout_test.py +54 -0
  82. brainstate/nn/_dynamics/__init__.py +37 -0
  83. brainstate/nn/_dynamics/_dynamics_base.py +631 -0
  84. brainstate/nn/_dynamics/_dynamics_base_test.py +79 -0
  85. brainstate/nn/_dynamics/_projection_base.py +346 -0
  86. brainstate/nn/_dynamics/_state_delay.py +453 -0
  87. brainstate/nn/_dynamics/_synouts.py +161 -0
  88. brainstate/nn/_dynamics/_synouts_test.py +58 -0
  89. brainstate/nn/_elementwise/__init__.py +22 -0
  90. brainstate/nn/_elementwise/_dropout.py +418 -0
  91. brainstate/nn/_elementwise/_dropout_test.py +100 -0
  92. brainstate/nn/_elementwise/_elementwise.py +1122 -0
  93. brainstate/nn/_elementwise/_elementwise_test.py +171 -0
  94. brainstate/nn/_exp_euler.py +97 -0
  95. brainstate/nn/_exp_euler_test.py +36 -0
  96. brainstate/nn/_interaction/__init__.py +41 -0
  97. brainstate/nn/_interaction/_conv.py +499 -0
  98. brainstate/nn/_interaction/_conv_test.py +239 -0
  99. brainstate/nn/_interaction/_embedding.py +59 -0
  100. brainstate/nn/_interaction/_linear.py +582 -0
  101. brainstate/nn/_interaction/_linear_test.py +42 -0
  102. brainstate/nn/_interaction/_normalizations.py +388 -0
  103. brainstate/nn/_interaction/_normalizations_test.py +75 -0
  104. brainstate/nn/_interaction/_poolings.py +1179 -0
  105. brainstate/nn/_interaction/_poolings_test.py +219 -0
  106. brainstate/nn/_module.py +328 -0
  107. brainstate/nn/_module_test.py +211 -0
  108. brainstate/nn/metrics.py +309 -309
  109. brainstate/optim/__init__.py +14 -2
  110. brainstate/optim/_base.py +66 -0
  111. brainstate/optim/_lr_scheduler.py +363 -400
  112. brainstate/optim/_lr_scheduler_test.py +25 -24
  113. brainstate/optim/_optax_optimizer.py +121 -176
  114. brainstate/optim/_optax_optimizer_test.py +41 -1
  115. brainstate/optim/_sgd_optimizer.py +950 -1025
  116. brainstate/random/_rand_funs.py +3269 -3268
  117. brainstate/random/_rand_funs_test.py +568 -0
  118. brainstate/random/_rand_seed.py +149 -117
  119. brainstate/random/_rand_seed_test.py +50 -0
  120. brainstate/random/_rand_state.py +1356 -1321
  121. brainstate/random/_random_for_unit.py +13 -13
  122. brainstate/surrogate.py +1262 -1243
  123. brainstate/{nn/_projection/_utils.py → transform.py} +1 -2
  124. brainstate/typing.py +157 -130
  125. brainstate/util/__init__.py +52 -0
  126. brainstate/util/_caller.py +100 -0
  127. brainstate/util/_dict.py +734 -0
  128. brainstate/util/_dict_test.py +160 -0
  129. brainstate/{nn/_projection/__init__.py → util/_error.py} +9 -13
  130. brainstate/util/_filter.py +178 -0
  131. brainstate/util/_others.py +497 -0
  132. brainstate/util/_pretty_repr.py +208 -0
  133. brainstate/util/_scaling.py +260 -0
  134. brainstate/util/_struct.py +524 -0
  135. brainstate/util/_tracers.py +75 -0
  136. brainstate/{_visualization.py → util/_visualization.py} +16 -16
  137. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/METADATA +11 -11
  138. brainstate-0.1.0.post20241122.dist-info/RECORD +144 -0
  139. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/top_level.txt +1 -0
  140. brainstate/_module.py +0 -1637
  141. brainstate/_module_test.py +0 -207
  142. brainstate/nn/_base.py +0 -251
  143. brainstate/nn/_connections.py +0 -686
  144. brainstate/nn/_dynamics.py +0 -426
  145. brainstate/nn/_elementwise.py +0 -1438
  146. brainstate/nn/_embedding.py +0 -66
  147. brainstate/nn/_misc.py +0 -133
  148. brainstate/nn/_normalizations.py +0 -389
  149. brainstate/nn/_others.py +0 -101
  150. brainstate/nn/_poolings.py +0 -1229
  151. brainstate/nn/_poolings_test.py +0 -231
  152. brainstate/nn/_projection/_align_post.py +0 -546
  153. brainstate/nn/_projection/_align_pre.py +0 -599
  154. brainstate/nn/_projection/_delta.py +0 -241
  155. brainstate/nn/_projection/_vanilla.py +0 -101
  156. brainstate/nn/_rate_rnns.py +0 -410
  157. brainstate/nn/_readout.py +0 -136
  158. brainstate/nn/_synouts.py +0 -166
  159. brainstate/nn/event/csr.py +0 -312
  160. brainstate/nn/event/csr_test.py +0 -118
  161. brainstate/nn/event/fixed_probability.py +0 -276
  162. brainstate/nn/event/fixed_probability_test.py +0 -127
  163. brainstate/nn/event/linear.py +0 -220
  164. brainstate/nn/event/linear_test.py +0 -111
  165. brainstate/random/random_test.py +0 -593
  166. brainstate/transform/_autograd.py +0 -585
  167. brainstate/transform/_autograd_test.py +0 -1181
  168. brainstate/transform/_conditions.py +0 -334
  169. brainstate/transform/_conditions_test.py +0 -220
  170. brainstate/transform/_error_if.py +0 -94
  171. brainstate/transform/_error_if_test.py +0 -55
  172. brainstate/transform/_jit.py +0 -265
  173. brainstate/transform/_jit_test.py +0 -118
  174. brainstate/transform/_loop_collect_return.py +0 -502
  175. brainstate/transform/_loop_no_collection.py +0 -170
  176. brainstate/transform/_make_jaxpr.py +0 -739
  177. brainstate/transform/_make_jaxpr_test.py +0 -131
  178. brainstate/transform/_mapping.py +0 -109
  179. brainstate/transform/_progress_bar.py +0 -111
  180. brainstate/transform/_unvmap.py +0 -143
  181. brainstate/util.py +0 -746
  182. brainstate-0.0.2.post20241010.dist-info/RECORD +0 -87
  183. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/LICENSE +0 -0
  184. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/WHEEL +0 -0
@@ -0,0 +1,315 @@
1
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ # -*- coding: utf-8 -*-
17
+
18
+ from __future__ import annotations
19
+
20
+ from typing import Optional
21
+
22
+ import brainunit as u
23
+
24
+ from brainstate import init, environ
25
+ from brainstate._state import ShortTermState, HiddenState
26
+ from brainstate.mixin import AlignPost
27
+ from brainstate.nn._dynamics._dynamics_base import Dynamics
28
+ from brainstate.nn._exp_euler import exp_euler_step
29
+ from brainstate.typing import ArrayLike, Size
30
+
31
+ __all__ = [
32
+ 'Synapse', 'Expon', 'STP', 'STD', 'AMPA', 'GABAa',
33
+ ]
34
+
35
+
36
class Synapse(Dynamics):
    """Common base class shared by the synapse dynamics models of ``brainstate.nn``.

    It adds no behavior of its own on top of :class:`Dynamics`; concrete
    subclasses create their state variables in ``init_state`` /
    ``reset_state`` and advance them by one step in ``update``.
    """
    __module__ = 'brainstate.nn'
41
+
42
+
43
class Expon(Synapse, AlignPost):
    r"""Exponential decay synapse model.

    The conductance :math:`g` is advanced with one exponential-Euler step of
    the drift ``sum_current_inputs(-g) / tau``, i.e. (without extra inputs)

    .. math::

        \frac{dg}{dt} = -\frac{g}{\tau},

    after which registered delta inputs and the optional ``x`` argument of
    :meth:`update` are added to :math:`g` instantaneously.

    Args:
        in_size: Size. The shape of the synapse population.
        name: Optional[str]. The name of the module.
        tau: ArrayLike. The time constant of decay. Default ``8 ms``.
        g_initializer: ArrayLike or initializer. Initial value of the
            conductance. Default zeros in ``mS``.
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        in_size: Size,
        name: Optional[str] = None,
        tau: ArrayLike = 8.0 * u.ms,
        g_initializer: ArrayLike = init.ZeroInit(unit=u.mS),
    ):
        super().__init__(name=name, in_size=in_size)

        # parameters
        self.tau = init.param(tau, self.varshape)
        self.g_initializer = g_initializer

    def init_state(self, batch_size: int = None, **kwargs):
        """Create the conductance state ``g`` (optionally with a batch axis)."""
        self.g = HiddenState(init.param(self.g_initializer, self.varshape, batch_size))

    def reset_state(self, batch_size: int = None, **kwargs):
        """Reset ``g`` to its initial value."""
        self.g.value = init.param(self.g_initializer, self.varshape, batch_size)

    def update(self, x=None):
        """Advance the conductance by one step.

        Args:
            x: Optional increment added to ``g`` after integration (e.g. the
                spike-driven conductance jump).

        Returns:
            The updated conductance ``g``.
        """
        g = exp_euler_step(lambda g: self.sum_current_inputs(-g) / self.tau, self.g.value)
        self.g.value = self.sum_delta_inputs(g)
        if x is not None:
            self.g.value += x
        return self.g.value
76
+
77
+
78
class STP(Synapse):
    r"""Synaptic output with short-term plasticity.

    Between spikes the utilization :math:`u` decays to zero and the available
    resources :math:`x` recover to one; on a pre-synaptic spike both jump:

    .. math::

        \frac{du}{dt} &= -\frac{u}{\tau_f} + U(1-u^-)\delta(t-t_{sp}) \\
        \frac{dx}{dt} &= \frac{1-x}{\tau_d} - u^+ x^- \delta(t-t_{sp})

    where :math:`u^-, x^-` are the values before the spike and :math:`u^+`
    the utilization after the spike jump.

    Args:
        in_size: Size. The shape of the synapse population.
        name: Optional[str]. The name of the module.
        U: float, ArrayType, Callable. The fraction of resources used per action potential.
        tau_f: float, ArrayType, Callable. The time constant of short-term facilitation.
        tau_d: float, ArrayType, Callable. The time constant of short-term depression.
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        in_size: Size,
        name: Optional[str] = None,
        U: ArrayLike = 0.15,
        tau_f: ArrayLike = 1500. * u.ms,
        tau_d: ArrayLike = 200. * u.ms,
    ):
        super().__init__(name=name, in_size=in_size)

        # parameters
        self.tau_f = init.param(tau_f, self.varshape)
        self.tau_d = init.param(tau_d, self.varshape)
        self.U = init.param(U, self.varshape)

    def init_state(self, batch_size: int = None, **kwargs):
        """Create ``x`` (initialized to 1) and ``u`` (initialized to ``U``)."""
        self.x = HiddenState(init.param(init.Constant(1.), self.varshape, batch_size))
        self.u = HiddenState(init.param(init.Constant(self.U), self.varshape, batch_size))

    def reset_state(self, batch_size: int = None, **kwargs):
        """Reset ``x`` to 1 and ``u`` to ``U``."""
        self.x.value = init.param(init.Constant(1.), self.varshape, batch_size)
        self.u.value = init.param(init.Constant(self.U), self.varshape, batch_size)

    def update(self, pre_spike):
        """Advance ``u`` and ``x`` by one step and return the efficacy ``u * x * pre_spike``."""
        # Continuous part.  The facilitation variable decays towards zero;
        # the previous drift ``U - u / tau_f`` mixed a dimensionless value with
        # a 1/time value (a unit error with the default ``tau_f`` in ms) and
        # made ``u`` relax to ``U * tau_f`` instead of zero.
        u = exp_euler_step(lambda u: -u / self.tau_f, self.u.value)
        x = exp_euler_step(lambda x: (1 - x) / self.tau_d, self.x.value)

        # Spike-triggered jumps.  Multiplying by ``pre_spike`` is equivalent to
        # ``where(pre_spike, jumped, unjumped)`` for boolean / 0-1 spikes.
        u = u + pre_spike * self.U * (1 - self.u.value)
        x = x - pre_spike * u * self.x.value

        self.u.value = u
        self.x.value = x
        return u * x * pre_spike
133
+
134
+
135
class STD(Synapse):
    r"""Synaptic output with short-term depression.

    The available resources :math:`x` recover towards one with time constant
    :math:`\tau` and are depleted by a fraction ``U`` on each pre-synaptic
    spike:

    .. math::

        \frac{dx}{dt} = \frac{1-x}{\tau} - U x^- \delta(t-t_{sp})

    Args:
        in_size: Size. The shape of the synapse population.
        name: Optional[str]. The name of the module.
        tau: float, ArrayType, Callable. The time constant of recovery of the synaptic vesicles.
        U: float, ArrayType, Callable. The fraction of resources used per action potential.
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        in_size: Size,
        name: Optional[str] = None,
        # synapse parameters
        tau: ArrayLike = 200. * u.ms,
        U: ArrayLike = 0.07,
    ):
        super().__init__(name=name, in_size=in_size)

        # parameters
        self.tau = init.param(tau, self.varshape)
        self.U = init.param(U, self.varshape)

    def init_state(self, batch_size: int = None, **kwargs):
        """Create the resource state ``x`` initialized to 1."""
        self.x = HiddenState(init.param(init.Constant(1.), self.varshape, batch_size))

    def reset_state(self, batch_size: int = None, **kwargs):
        """Reset ``x`` to 1."""
        self.x.value = init.param(init.Constant(1.), self.varshape, batch_size)

    def update(self, pre_spike):
        """Advance ``x`` by one step and return ``x * pre_spike`` (using the updated ``x``)."""
        # Recovery between spikes.
        x = exp_euler_step(lambda x: (1 - x) / self.tau, self.x.value)

        # Depletion on spikes; the multiplication by ``pre_spike`` replaces an
        # explicit ``where(pre_spike, x - U * x_old, x)``.
        self.x.value = x - pre_spike * self.U * self.x.value

        return self.x.value * pre_spike
177
+
178
+
179
class AMPA(Synapse):
    r"""AMPA synapse model.

    **Model Descriptions**

    AMPA receptor is an ionotropic receptor, which is an ion channel.
    When it is bound by neurotransmitters, it will immediately open the
    ion channel, causing the change of membrane potential of postsynaptic neurons.

    A classical model is to use the Markov process to model ion channel switch.
    Here :math:`g` represents the probability of channel opening, :math:`1-g`
    represents the probability of ion channel closing, and :math:`\alpha` and
    :math:`\beta` are the transition probability. Because neurotransmitters can
    open ion channels, the transfer probability from :math:`1-g` to :math:`g`
    is affected by the concentration of neurotransmitters. We denote the concentration
    of neurotransmitters as :math:`[T]` and get the following Markov process.

    .. image:: ../../_static/synapse_markov.png
        :align: center

    We obtained the following formula when describing the process by a differential equation.

    .. math::

        \frac{dg}{dt} =\alpha[T](1-g)-\beta g

    where :math:`\alpha [T]` denotes the transition probability from state :math:`(1-g)`
    to state :math:`(g)`; and :math:`\beta` represents the transition probability of
    the other direction. :math:`\alpha` is the binding constant. :math:`\beta` is the
    unbinding constant. :math:`[T]` is the neurotransmitter concentration; it equals
    ``T`` for ``T_dur`` after the most recent pre-synaptic spike and zero otherwise.

    Moreover, the post-synaptic current on the post-synaptic neuron is formulated as

    .. math::

        I_{syn} = g_{max} g (V-E)

    where :math:`g_{max}` is the maximum conductance, and `E` is the reverse potential.

    .. [1] Vijayan S, Kopell N J. Thalamic model of awake alpha oscillations
           and implications for stimulus processing[J]. Proceedings of the
           National Academy of Sciences, 2012, 109(45): 18553-18558.

    Args:
        in_size: Size. The shape of the synapse population.
        name: Optional[str]. The name of the module.
        alpha: float, ArrayType, Callable. Binding constant. Default 0.98 [ms^-1 mM^-1].
        beta: float, ArrayType, Callable. Unbinding constant. Default 0.18 [ms^-1].
        T: float, ArrayType, Callable. Transmitter concentration when synapse is triggered by
            a pre-synaptic spike. Default 0.5 [mM].
        T_dur: float, ArrayType, Callable. Transmitter concentration duration time after
            being triggered. Default 0.5 [ms].
        g_initializer: ArrayLike or initializer. Initial open probability ``g``. Default zeros.
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        in_size: Size,
        name: Optional[str] = None,
        alpha: ArrayLike = 0.98 / (u.ms * u.mM),
        beta: ArrayLike = 0.18 / u.ms,
        T: ArrayLike = 0.5 * u.mM,
        T_dur: ArrayLike = 0.5 * u.ms,
        g_initializer: ArrayLike = init.ZeroInit(),
    ):
        super().__init__(name=name, in_size=in_size)

        # parameters
        self.alpha = init.param(alpha, self.varshape)
        self.beta = init.param(beta, self.varshape)
        self.T = init.param(T, self.varshape)
        self.T_duration = init.param(T_dur, self.varshape)
        self.g_initializer = g_initializer

    def init_state(self, batch_size=None, **kwargs):
        """Create ``g`` and the last-spike-arrival time (initialized far in the past)."""
        self.g = HiddenState(init.param(self.g_initializer, self.varshape, batch_size))
        self.spike_arrival_time = ShortTermState(init.param(init.Constant(-1e7 * u.ms), self.varshape, batch_size))

    def reset_state(self, batch_or_mode=None, **kwargs):
        """Reset ``g`` and the last-spike-arrival time."""
        self.g.value = init.param(self.g_initializer, self.varshape, batch_or_mode)
        self.spike_arrival_time.value = init.param(init.Constant(-1e7 * u.ms), self.varshape, batch_or_mode)

    def dg(self, g, t, TT):
        """Right-hand side ``alpha * TT * (1 - g) - beta * g`` of the gating equation."""
        return self.alpha * TT * (1 - g) - self.beta * g

    def update(self, pre_spike):
        """Advance ``g`` by one step given the pre-synaptic spikes; returns the new ``g``."""
        t = environ.get('t')
        # Remember when the latest spike arrived at each synapse.
        self.spike_arrival_time.value = u.math.where(pre_spike, t, self.spike_arrival_time.value)
        # Transmitter is present (value ``T``) only within ``T_duration`` of that spike.
        TT = ((t - self.spike_arrival_time.value) < self.T_duration) * self.T
        self.g.value = exp_euler_step(self.dg, self.g.value, t, TT)
        return self.g.value
269
+
270
+
271
class GABAa(AMPA):
    r"""GABAa synapse model.

    **Model Descriptions**

    GABAa synapse model has the same equation as the :class:`AMPA` synapse,

    .. math::

        \frac{d g}{d t}&=\alpha[T](1-g) - \beta g \\
        I_{syn}&= - g_{max} g (V - E)

    but with the difference of:

    - Reversal potential of synapse :math:`E` is usually low, typically -80. mV
    - Activating rate constant :math:`\alpha=0.53`
    - De-activating rate constant :math:`\beta=0.18`
    - Transmitter concentration :math:`[T]=1\,\mathrm{mM}` when synapse is
      triggered by a pre-synaptic spike, with the duration of 1. ms.

    .. [1] Destexhe, Alain, and Denis Paré. "Impact of network activity
           on the integrative properties of neocortical pyramidal neurons
           in vivo." Journal of neurophysiology 81.4 (1999): 1531-1547.

    Args:
        in_size: Size. The shape of the synapse population.
        name: Optional[str]. The name of the module.
        alpha: float, ArrayType, Callable. Binding constant. Default 0.53 [ms^-1 mM^-1].
        beta: float, ArrayType, Callable. Unbinding constant. Default 0.18 [ms^-1].
        T: float, ArrayType, Callable. Transmitter concentration when synapse is triggered by
            a pre-synaptic spike. Default 1 [mM].
        T_dur: float, ArrayType, Callable. Transmitter concentration duration time
            after being triggered. Default 1 [ms].
        g_initializer: ArrayLike or initializer. Initial open probability ``g``. Default zeros.
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        in_size: Size,
        name: Optional[str] = None,
        alpha: ArrayLike = 0.53 / (u.ms * u.mM),
        beta: ArrayLike = 0.18 / u.ms,
        T: ArrayLike = 1.0 * u.mM,
        T_dur: ArrayLike = 1.0 * u.ms,
        g_initializer: ArrayLike = init.ZeroInit(),
    ):
        super().__init__(
            alpha=alpha,
            beta=beta,
            T=T,
            T_dur=T_dur,
            name=name,
            in_size=in_size,
            g_initializer=g_initializer,
        )
@@ -0,0 +1,132 @@
1
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ from __future__ import annotations
17
+
18
+ import unittest
19
+
20
+ import brainunit as u
21
+ import jax.numpy as jnp
22
+ import pytest
23
+
24
+ import brainstate as bst
25
+ from brainstate.nn import Expon, STP, STD
26
+
27
+
28
class TestSynapse(unittest.TestCase):
    """Shape and qualitative-behavior tests for the synapse models."""

    def setUp(self):
        self.in_size = 10
        self.batch_size = 5
        self.time_steps = 100

    def generate_input(self):
        """Random conductance inputs of shape (time_steps, batch_size, in_size) in mS."""
        return bst.random.randn(self.time_steps, self.batch_size, self.in_size) * u.mS

    def test_expon_synapse(self):
        tau = 20.0 * u.ms
        synapse = Expon(self.in_size, tau=tau)
        inputs = self.generate_input()

        # Test initialization
        self.assertEqual(synapse.in_size, (self.in_size,))
        self.assertEqual(synapse.out_size, (self.in_size,))
        self.assertEqual(synapse.tau, tau)

        synapse.init_state(self.batch_size)
        call = bst.compile.jit(synapse)
        # Every step integrates with ``dt``; keep all calls, including the
        # decay check, inside the same environment context.
        with bst.environ.context(dt=0.1 * u.ms):
            # Test forward pass
            for t in range(self.time_steps):
                out = call(inputs[t])
                self.assertEqual(out.shape, (self.batch_size, self.in_size))

            # Test exponential decay
            constant_input = jnp.ones((self.batch_size, self.in_size)) * u.mS
            out1 = call(constant_input)
            out2 = call(constant_input)
            self.assertTrue(jnp.all(out2 > out1))  # Output should increase with constant input

    @pytest.mark.skip(reason="Not implemented yet")
    def test_stp_synapse(self):
        tau_d = 200.0 * u.ms
        tau_f = 20.0 * u.ms
        U = 0.2
        synapse = STP(self.in_size, tau_d=tau_d, tau_f=tau_f, U=U)
        inputs = self.generate_input()

        # Test initialization
        self.assertEqual(synapse.in_size, (self.in_size,))
        self.assertEqual(synapse.out_size, (self.in_size,))
        self.assertEqual(synapse.tau_d, tau_d)
        self.assertEqual(synapse.tau_f, tau_f)
        self.assertEqual(synapse.U, U)

        synapse.init_state(self.batch_size)
        call = bst.compile.jit(synapse)
        with bst.environ.context(dt=0.1 * u.ms):
            # Test forward pass
            for t in range(self.time_steps):
                out = call(inputs[t])
                self.assertEqual(out.shape, (self.batch_size, self.in_size))

            # Test short-term plasticity
            constant_input = jnp.ones((self.batch_size, self.in_size)) * u.mS
            out1 = call(constant_input)
            out2 = call(constant_input)
            self.assertTrue(jnp.any(out2 != out1))  # Output should change due to STP

    @pytest.mark.skip(reason="Not implemented yet")
    def test_std_synapse(self):
        # Time constant carries a unit, matching the model's ms-valued default.
        tau = 200.0 * u.ms
        U = 0.2
        synapse = STD(self.in_size, tau=tau, U=U)
        inputs = self.generate_input()

        # Test initialization
        self.assertEqual(synapse.in_size, (self.in_size,))
        self.assertEqual(synapse.out_size, (self.in_size,))
        self.assertEqual(synapse.tau, tau)
        self.assertEqual(synapse.U, U)

        synapse.init_state(self.batch_size)
        with bst.environ.context(dt=0.1 * u.ms):
            # Test forward pass
            for t in range(self.time_steps):
                out = synapse(inputs[t])
                self.assertEqual(out.shape, (self.batch_size, self.in_size))

            # Test short-term depression
            constant_input = jnp.ones((self.batch_size, self.in_size))
            out1 = synapse(constant_input)
            out2 = synapse(constant_input)
            self.assertTrue(jnp.all(out2 < out1))  # Output should decrease due to STD

    def test_keep_size(self):
        in_size = (2, 3)
        for SynapseClass in [Expon, ]:
            synapse = SynapseClass(in_size)
            self.assertEqual(synapse.in_size, in_size)
            self.assertEqual(synapse.out_size, in_size)

            inputs = bst.random.randn(self.time_steps, self.batch_size, *in_size) * u.mS
            synapse.init_state(self.batch_size)
            call = bst.compile.jit(synapse)
            with bst.environ.context(dt=0.1 * u.ms):
                for t in range(self.time_steps):
                    out = call(inputs[t])
                    self.assertEqual(out.shape, (self.batch_size, *in_size))
128
+
129
+
130
if __name__ == '__main__':
    # The synapse models use ms-valued time constants, so ``dt`` needs a unit too.
    with bst.environ.context(dt=0.1 * u.ms):
        unittest.main()
@@ -0,0 +1,154 @@
1
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ from __future__ import annotations
16
+
17
+ from typing import Union, Optional, Sequence, Callable
18
+
19
+ import brainunit as u
20
+
21
+ from brainstate import environ, init, random
22
+ from brainstate._state import ShortTermState
23
+ from brainstate.compile import while_loop
24
+ from brainstate.nn._dynamics._dynamics_base import Dynamics
25
+ from brainstate.typing import ArrayLike, Size, DTypeLike
26
+
27
+ __all__ = [
28
+ 'SpikeTime',
29
+ 'PoissonSpike',
30
+ 'PoissonEncoder',
31
+ ]
32
+
33
+
34
class SpikeTime(Dynamics):
    """The input neuron group characterized by spikes emitting at given times.

    >>> # Get 2 neurons, firing spikes at 10 ms and 20 ms.
    >>> SpikeTime(2, times=[10, 20])
    >>> # or
    >>> # Get 2 neurons, the neuron 0 fires spikes at 10 ms and 20 ms.
    >>> SpikeTime(2, times=[10, 20], indices=[0, 0])
    >>> # or
    >>> # Get 2 neurons, neuron 0 fires at 10 ms and 30 ms, neuron 1 fires at 20 ms.
    >>> SpikeTime(2, times=[10, 20, 30], indices=[0, 1, 0])
    >>> # or
    >>> # Get 2 neurons; at 10 ms, neuron 0 fires; at 20 ms, neuron 0 and 1 fire;
    >>> # at 30 ms, neuron 1 fires.
    >>> SpikeTime(2, times=[10, 20, 20, 30], indices=[0, 0, 1, 1])

    Parameters
    ----------
    in_size : int, tuple, list
        The neuron group geometry.
    indices : list, tuple, ArrayType
        The neuron indices at each time point to emit spikes.
    times : list, tuple, ArrayType
        The time points which generate the spikes.
    spk_type : dtype
        The dtype of the emitted spike array. Default ``bool``.
    name : str, optional
        The name of the dynamic system.
    need_sort : bool
        Sort the events by time. Must be ``True`` unless ``times`` is already
        ascending, because :meth:`update` walks the events in order.

    Raises
    ------
    ValueError
        If ``indices`` and ``times`` have different lengths.
    """

    def __init__(
        self,
        in_size: Size,
        indices: Union[Sequence, ArrayLike],
        times: Union[Sequence, ArrayLike],
        spk_type: DTypeLike = bool,
        name: Optional[str] = None,
        need_sort: bool = True,
    ):
        super().__init__(in_size=in_size, name=name)

        # parameters
        if len(indices) != len(times):
            raise ValueError(f'The length of "indices" and "times" must be the same. '
                             f'However, we got {len(indices)} != {len(times)}.')
        self.num_times = len(times)
        self.spk_type = spk_type

        # data about times and indices
        self.times = u.math.asarray(times)
        self.indices = u.math.asarray(indices, dtype=environ.ditype())
        if need_sort:
            sort_idx = u.math.argsort(self.times)
            self.indices = self.indices[sort_idx]
            self.times = self.times[sort_idx]

    def init_state(self, *args, **kwargs):
        """Create the event cursor ``i``: the index of the next not-yet-emitted event."""
        # Must start at 0: starting at -1 made the first check compare ``t``
        # against ``times[-1]`` (the *last* event), so no spike was emitted
        # until the final event time.
        self.i = ShortTermState(0)

    def reset_state(self, batch_size=None, **kwargs):
        """Rewind the event cursor to the first event."""
        self.i.value = 0

    def update(self):
        """Emit every event whose time is ``<= t`` and that has not been emitted yet.

        Returns
        -------
        An array of shape ``varshape`` and dtype ``spk_type`` with ``True``
        (or 1) at the indices that fire at the current time.
        """
        t = environ.get('t')
        spike = u.math.zeros(self.varshape, dtype=self.spk_type)

        # No events at all: indexing ``self.times`` would be invalid.
        if self.num_times == 0:
            return spike

        def _cond_fun(spikes):
            i = self.i.value
            return u.math.logical_and(i < self.num_times, t >= self.times[i])

        def _body_fun(spikes):
            i = self.i.value
            spikes = spikes.at[..., self.indices[i]].set(True)
            self.i.value += 1
            return spikes

        spike = while_loop(_cond_fun, _body_fun, spike)
        return spike
110
+
111
+
112
class PoissonSpike(Dynamics):
    """
    Poisson neuron group with fixed firing rates.

    At every step each neuron fires independently when a uniform sample in
    ``[0, 1)`` is ``<= freqs * dt``.

    Args:
        in_size: Size. The neuron group geometry.
        freqs: ArrayLike or Callable. Firing rates, initialized over ``varshape``.
        spk_type: DTypeLike. The dtype of the emitted spike array. Default ``bool``.
        name: Optional[str]. The name of the module.
    """

    def __init__(
        self,
        in_size: Size,
        freqs: Union[ArrayLike, Callable],
        spk_type: DTypeLike = bool,
        name: Optional[str] = None,
    ):
        super().__init__(in_size=in_size, name=name)

        self.spk_type = spk_type

        # parameters
        self.freqs = init.param(freqs, self.varshape, allow_none=False)

    def update(self):
        """Sample one step of spikes; returns an array of shape ``varshape``."""
        # ``random.rand`` follows the numpy ``rand(*dn)`` convention, so the
        # shape must be unpacked (as ``PoissonEncoder`` does); passing the tuple
        # itself treats it as a single dimension.
        spikes = random.rand(*self.varshape) <= (self.freqs * environ.get_dt())
        spikes = u.math.asarray(spikes, dtype=self.spk_type)
        return spikes
135
+
136
+
137
class PoissonEncoder(Dynamics):
    """
    Encode firing rates given at call time into Poisson spikes.

    Unlike a fixed-rate Poisson group, the rates are passed to :meth:`update`
    on every step; a neuron fires when a uniform sample in ``[0, 1)`` is
    ``<= freqs * dt``.

    Args:
        in_size: Size. The neuron group geometry.
        spk_type: DTypeLike. The dtype of the emitted spike array. Default ``bool``.
        name: Optional[str]. The name of the module.
    """

    def __init__(
        self,
        in_size: Size,
        spk_type: DTypeLike = bool,
        name: Optional[str] = None,
    ):
        super().__init__(in_size=in_size, name=name)
        self.spk_type = spk_type

    def update(self, freqs: ArrayLike):
        """Sample one step of spikes for the given rates ``freqs``."""
        samples = random.rand(*self.varshape)
        fired = samples <= (freqs * environ.get_dt())
        return u.math.asarray(fired, dtype=self.spk_type)
@@ -13,13 +13,13 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
+ from __future__ import annotations
16
17
 
17
- from .csr import *
18
- from .csr import __all__ as __all_csr
19
- from .fixed_probability import *
20
- from .fixed_probability import __all__ as __all_fixed_probability
21
- from .linear import *
22
- from .linear import __all__ as __all_linear
18
+ from brainstate.nn._dynamics._dynamics_base import Projection
23
19
 
24
- __all__ = __all_fixed_probability + __all_linear + __all_csr
25
- del __all_fixed_probability, __all_linear, __all_csr
20
+ __all__ = [
21
+ ]
22
+
23
+
24
class ExponentialSynapse(Projection):
    """Placeholder projection type; it currently adds nothing beyond :class:`Projection`."""