brainstate 0.0.2.post20241010__py2.py3-none-any.whl → 0.1.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (175)
  1. brainstate/__init__.py +31 -11
  2. brainstate/_state.py +760 -316
  3. brainstate/_state_test.py +41 -12
  4. brainstate/_utils.py +31 -4
  5. brainstate/augment/__init__.py +40 -0
  6. brainstate/augment/_autograd.py +608 -0
  7. brainstate/augment/_autograd_test.py +1193 -0
  8. brainstate/augment/_eval_shape.py +102 -0
  9. brainstate/augment/_eval_shape_test.py +40 -0
  10. brainstate/augment/_mapping.py +525 -0
  11. brainstate/augment/_mapping_test.py +210 -0
  12. brainstate/augment/_random.py +99 -0
  13. brainstate/{transform → compile}/__init__.py +25 -13
  14. brainstate/compile/_ad_checkpoint.py +204 -0
  15. brainstate/compile/_ad_checkpoint_test.py +51 -0
  16. brainstate/compile/_conditions.py +259 -0
  17. brainstate/compile/_conditions_test.py +221 -0
  18. brainstate/compile/_error_if.py +94 -0
  19. brainstate/compile/_error_if_test.py +54 -0
  20. brainstate/compile/_jit.py +314 -0
  21. brainstate/compile/_jit_test.py +143 -0
  22. brainstate/compile/_loop_collect_return.py +516 -0
  23. brainstate/compile/_loop_collect_return_test.py +59 -0
  24. brainstate/compile/_loop_no_collection.py +185 -0
  25. brainstate/compile/_loop_no_collection_test.py +51 -0
  26. brainstate/compile/_make_jaxpr.py +756 -0
  27. brainstate/compile/_make_jaxpr_test.py +134 -0
  28. brainstate/compile/_progress_bar.py +111 -0
  29. brainstate/compile/_unvmap.py +159 -0
  30. brainstate/compile/_util.py +147 -0
  31. brainstate/environ.py +408 -381
  32. brainstate/environ_test.py +34 -32
  33. brainstate/{nn/event → event}/__init__.py +6 -6
  34. brainstate/event/_csr.py +308 -0
  35. brainstate/event/_csr_test.py +118 -0
  36. brainstate/event/_fixed_probability.py +271 -0
  37. brainstate/event/_fixed_probability_test.py +128 -0
  38. brainstate/event/_linear.py +219 -0
  39. brainstate/event/_linear_test.py +112 -0
  40. brainstate/{nn/event → event}/_misc.py +7 -7
  41. brainstate/functional/_activations.py +521 -511
  42. brainstate/functional/_activations_test.py +300 -300
  43. brainstate/functional/_normalization.py +43 -43
  44. brainstate/functional/_others.py +15 -15
  45. brainstate/functional/_spikes.py +49 -49
  46. brainstate/graph/__init__.py +33 -0
  47. brainstate/graph/_graph_context.py +443 -0
  48. brainstate/graph/_graph_context_test.py +65 -0
  49. brainstate/graph/_graph_convert.py +246 -0
  50. brainstate/graph/_graph_node.py +300 -0
  51. brainstate/graph/_graph_node_test.py +75 -0
  52. brainstate/graph/_graph_operation.py +1746 -0
  53. brainstate/graph/_graph_operation_test.py +724 -0
  54. brainstate/init/_base.py +28 -10
  55. brainstate/init/_generic.py +175 -172
  56. brainstate/init/_random_inits.py +470 -415
  57. brainstate/init/_random_inits_test.py +150 -0
  58. brainstate/init/_regular_inits.py +66 -69
  59. brainstate/init/_regular_inits_test.py +51 -0
  60. brainstate/mixin.py +236 -244
  61. brainstate/mixin_test.py +44 -46
  62. brainstate/nn/__init__.py +26 -51
  63. brainstate/nn/_collective_ops.py +199 -0
  64. brainstate/nn/_dyn_impl/__init__.py +46 -0
  65. brainstate/nn/_dyn_impl/_dynamics_neuron.py +290 -0
  66. brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +162 -0
  67. brainstate/nn/_dyn_impl/_dynamics_synapse.py +320 -0
  68. brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +132 -0
  69. brainstate/nn/_dyn_impl/_inputs.py +154 -0
  70. brainstate/nn/{_projection/__init__.py → _dyn_impl/_projection_alignpost.py} +6 -13
  71. brainstate/nn/_dyn_impl/_rate_rnns.py +400 -0
  72. brainstate/nn/_dyn_impl/_rate_rnns_test.py +64 -0
  73. brainstate/nn/_dyn_impl/_readout.py +128 -0
  74. brainstate/nn/_dyn_impl/_readout_test.py +54 -0
  75. brainstate/nn/_dynamics/__init__.py +37 -0
  76. brainstate/nn/_dynamics/_dynamics_base.py +631 -0
  77. brainstate/nn/_dynamics/_dynamics_base_test.py +79 -0
  78. brainstate/nn/_dynamics/_projection_base.py +346 -0
  79. brainstate/nn/_dynamics/_state_delay.py +453 -0
  80. brainstate/nn/_dynamics/_synouts.py +161 -0
  81. brainstate/nn/_dynamics/_synouts_test.py +58 -0
  82. brainstate/nn/_elementwise/__init__.py +22 -0
  83. brainstate/nn/_elementwise/_dropout.py +418 -0
  84. brainstate/nn/_elementwise/_dropout_test.py +100 -0
  85. brainstate/nn/_elementwise/_elementwise.py +1122 -0
  86. brainstate/nn/_elementwise/_elementwise_test.py +171 -0
  87. brainstate/nn/_exp_euler.py +97 -0
  88. brainstate/nn/_exp_euler_test.py +36 -0
  89. brainstate/nn/_interaction/__init__.py +32 -0
  90. brainstate/nn/_interaction/_connections.py +726 -0
  91. brainstate/nn/_interaction/_connections_test.py +254 -0
  92. brainstate/nn/_interaction/_embedding.py +59 -0
  93. brainstate/nn/_interaction/_normalizations.py +388 -0
  94. brainstate/nn/_interaction/_normalizations_test.py +75 -0
  95. brainstate/nn/_interaction/_poolings.py +1179 -0
  96. brainstate/nn/_interaction/_poolings_test.py +219 -0
  97. brainstate/nn/_module.py +328 -0
  98. brainstate/nn/_module_test.py +211 -0
  99. brainstate/nn/metrics.py +309 -309
  100. brainstate/optim/__init__.py +14 -2
  101. brainstate/optim/_base.py +66 -0
  102. brainstate/optim/_lr_scheduler.py +363 -400
  103. brainstate/optim/_lr_scheduler_test.py +25 -24
  104. brainstate/optim/_optax_optimizer.py +103 -176
  105. brainstate/optim/_optax_optimizer_test.py +41 -1
  106. brainstate/optim/_sgd_optimizer.py +950 -1025
  107. brainstate/random/_rand_funs.py +3269 -3268
  108. brainstate/random/_rand_funs_test.py +568 -0
  109. brainstate/random/_rand_seed.py +149 -117
  110. brainstate/random/_rand_seed_test.py +50 -0
  111. brainstate/random/_rand_state.py +1356 -1321
  112. brainstate/random/_random_for_unit.py +13 -13
  113. brainstate/surrogate.py +1262 -1243
  114. brainstate/{nn/_projection/_utils.py → transform.py} +1 -2
  115. brainstate/typing.py +157 -130
  116. brainstate/util/__init__.py +52 -0
  117. brainstate/util/_caller.py +100 -0
  118. brainstate/util/_dict.py +734 -0
  119. brainstate/util/_dict_test.py +160 -0
  120. brainstate/util/_error.py +28 -0
  121. brainstate/util/_filter.py +178 -0
  122. brainstate/util/_others.py +497 -0
  123. brainstate/util/_pretty_repr.py +208 -0
  124. brainstate/util/_scaling.py +260 -0
  125. brainstate/util/_struct.py +524 -0
  126. brainstate/util/_tracers.py +75 -0
  127. brainstate/{_visualization.py → util/_visualization.py} +16 -16
  128. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.dist-info}/METADATA +11 -11
  129. brainstate-0.1.0.dist-info/RECORD +135 -0
  130. brainstate/_module.py +0 -1637
  131. brainstate/_module_test.py +0 -207
  132. brainstate/nn/_base.py +0 -251
  133. brainstate/nn/_connections.py +0 -686
  134. brainstate/nn/_dynamics.py +0 -426
  135. brainstate/nn/_elementwise.py +0 -1438
  136. brainstate/nn/_embedding.py +0 -66
  137. brainstate/nn/_misc.py +0 -133
  138. brainstate/nn/_normalizations.py +0 -389
  139. brainstate/nn/_others.py +0 -101
  140. brainstate/nn/_poolings.py +0 -1229
  141. brainstate/nn/_poolings_test.py +0 -231
  142. brainstate/nn/_projection/_align_post.py +0 -546
  143. brainstate/nn/_projection/_align_pre.py +0 -599
  144. brainstate/nn/_projection/_delta.py +0 -241
  145. brainstate/nn/_projection/_vanilla.py +0 -101
  146. brainstate/nn/_rate_rnns.py +0 -410
  147. brainstate/nn/_readout.py +0 -136
  148. brainstate/nn/_synouts.py +0 -166
  149. brainstate/nn/event/csr.py +0 -312
  150. brainstate/nn/event/csr_test.py +0 -118
  151. brainstate/nn/event/fixed_probability.py +0 -276
  152. brainstate/nn/event/fixed_probability_test.py +0 -127
  153. brainstate/nn/event/linear.py +0 -220
  154. brainstate/nn/event/linear_test.py +0 -111
  155. brainstate/random/random_test.py +0 -593
  156. brainstate/transform/_autograd.py +0 -585
  157. brainstate/transform/_autograd_test.py +0 -1181
  158. brainstate/transform/_conditions.py +0 -334
  159. brainstate/transform/_conditions_test.py +0 -220
  160. brainstate/transform/_error_if.py +0 -94
  161. brainstate/transform/_error_if_test.py +0 -55
  162. brainstate/transform/_jit.py +0 -265
  163. brainstate/transform/_jit_test.py +0 -118
  164. brainstate/transform/_loop_collect_return.py +0 -502
  165. brainstate/transform/_loop_no_collection.py +0 -170
  166. brainstate/transform/_make_jaxpr.py +0 -739
  167. brainstate/transform/_make_jaxpr_test.py +0 -131
  168. brainstate/transform/_mapping.py +0 -109
  169. brainstate/transform/_progress_bar.py +0 -111
  170. brainstate/transform/_unvmap.py +0 -143
  171. brainstate/util.py +0 -746
  172. brainstate-0.0.2.post20241010.dist-info/RECORD +0 -87
  173. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.dist-info}/LICENSE +0 -0
  174. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.dist-info}/WHEEL +0 -0
  175. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,320 @@
1
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ # -*- coding: utf-8 -*-
17
+
18
+ from __future__ import annotations
19
+
20
+ from typing import Optional
21
+
22
+ import brainunit as u
23
+
24
+ from brainstate import init, environ
25
+ from brainstate._state import ShortTermState, HiddenState
26
+ from brainstate.mixin import AlignPost
27
+ from brainstate.nn._dynamics._dynamics_base import Dynamics
28
+ from brainstate.nn._exp_euler import exp_euler_step
29
+ from brainstate.typing import ArrayLike, Size
30
+
31
# Public synapse models exported by this module.
__all__ = [
    'Synapse',
    'Expon',
    'STP',
    'STD',
    'AMPA',
    'GABAa',
]
34
+
35
+
36
class Synapse(Dynamics):
    """Common base class for the synapse dynamics models of ``brainstate.nn``.

    It adds no behavior of its own on top of :class:`Dynamics`; concrete
    models such as exponential, short-term-plasticity and kinetic receptor
    synapses derive from it.
    """
    __module__ = 'brainstate.nn'
41
+
42
+
43
class Expon(Synapse, AlignPost):
    r"""Exponential decay synapse model.

    On every call to :meth:`update`, the conductance ``g`` is advanced by
    one exponential-Euler step of

    .. math::

        \frac{dg}{dt} = \frac{\mathrm{sum\_current\_inputs}(-g)}{\tau}.

    The result is then passed through ``sum_delta_inputs``. If an input
    ``x`` is given, it is added to ``g`` afterwards.

    Args:
        in_size: Size. Shape of the synapse group.
        name: Optional[str]. Name of the module.
        tau: ArrayLike. Decay time constant. Default ``8 ms``.
        g_initializer: ArrayLike. Value or initializer used to build ``g``
            in :meth:`init_state` / :meth:`reset_state`.
            Default ``init.ZeroInit(unit=u.mS)``.
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        in_size: Size,
        name: Optional[str] = None,
        tau: ArrayLike = 8.0 * u.ms,
        g_initializer: ArrayLike = init.ZeroInit(unit=u.mS),
    ):
        super().__init__(name=name, in_size=in_size)

        # Time constant broadcast to the group shape.
        self.tau = init.param(tau, self.varshape)
        # Kept so that ``g`` can be rebuilt on reset.
        self.g_initializer = g_initializer

    def init_state(self, batch_size: int = None, **kwargs):
        initial_g = init.param(self.g_initializer, self.varshape, batch_size)
        self.g = HiddenState(initial_g)

    def reset_state(self, batch_size: int = None, **kwargs):
        fresh_g = init.param(self.g_initializer, self.varshape, batch_size)
        self.g.value = fresh_g

    def update(self, x=None):
        def dg_dt(g):
            return self.sum_current_inputs(-g) / self.tau

        g_next = exp_euler_step(dg_dt, self.g.value)
        self.g.value = self.sum_delta_inputs(g_next)
        if x is not None:
            self.g.value += x
        return self.g.value
76
+
77
+
78
class STP(Synapse):
    r"""Synaptic output with short-term plasticity (facilitation and depression).

    Two state variables are kept per synapse:

    - the utilization ``u``, initialized to ``U``;
    - the fraction of available resources ``x``, initialized to 1.

    Between presynaptic spikes they relax according to

    .. math::

        \frac{du}{dt} = \frac{U - u}{\tau_f}, \qquad
        \frac{dx}{dt} = \frac{1 - x}{\tau_d},

    so ``u`` returns to its baseline ``U`` and ``x`` recovers to 1.
    On a presynaptic spike two jumps are applied, in order:

    1. ``u`` is increased by ``U (1 - u_prev)``, where ``u_prev`` is the value
       held before this step.
    2. ``x`` is decreased by ``u x_prev``, using the already-updated ``u``.

    :meth:`update` returns the synaptic efficacy ``u * x``.

    Args:
        in_size: Size. Shape of the synapse group.
        name: Optional[str]. Name of the module.
        U: ArrayLike. Baseline fraction of resources used per action
            potential (dimensionless). Default 0.15.
        tau_f: ArrayLike. Time constant of short-term facilitation.
            Default ``1500 ms``.
        tau_d: ArrayLike. Time constant of short-term depression.
            Default ``200 ms``.
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        in_size: Size,
        name: Optional[str] = None,
        U: ArrayLike = 0.15,
        tau_f: ArrayLike = 1500. * u.ms,
        tau_d: ArrayLike = 200. * u.ms,
    ):
        super().__init__(name=name, in_size=in_size)

        # parameters
        self.tau_f = init.param(tau_f, self.varshape)
        self.tau_d = init.param(tau_d, self.varshape)
        self.U = init.param(U, self.varshape)

    def init_state(self, batch_size: int = None, **kwargs):
        self.x = HiddenState(init.param(init.Constant(1.), self.varshape, batch_size))
        self.u = HiddenState(init.param(init.Constant(self.U), self.varshape, batch_size))

    def reset_state(self, batch_size: int = None, **kwargs):
        self.x.value = init.param(init.Constant(1.), self.varshape, batch_size)
        self.u.value = init.param(init.Constant(self.U), self.varshape, batch_size)

    def update(self, pre_spike):
        u_prev = self.u.value
        x_prev = self.x.value

        # Relaxation between spikes. The facilitation derivative is
        # (U - u) / tau_f, which has units of 1/time like its state/dt.
        # The previous form ``U - u / tau_f`` subtracted a rate from a
        # dimensionless number and, with unitless parameters, drove ``u``
        # toward U * tau_f >> 1.
        u_next = exp_euler_step(lambda u_: (self.U - u_) / self.tau_f, u_prev)
        x_next = exp_euler_step(lambda x_: (1 - x_) / self.tau_d, x_prev)

        # Spike-triggered jumps. For 0/1 (or boolean) spikes, multiplying by
        # ``pre_spike`` is equivalent to ``where(pre_spike, jumped, value)``.
        u_next = u_next + pre_spike * self.U * (1 - u_prev)
        x_next = x_next - pre_spike * u_next * x_prev

        self.u.value = u_next
        self.x.value = x_next
        return u_next * x_next
135
+
136
+
137
class STD(Synapse):
    r"""Synaptic output with short-term depression.

    The fraction of available resources ``x`` starts at 1. Between steps it
    relaxes with

    .. math::

        \frac{dx}{dt} = \frac{1 - x}{\tau}.

    A presynaptic spike removes ``U`` times the resources held before the
    step. :meth:`update` returns the new ``x``.

    Args:
        in_size: Size. Shape of the synapse group.
        name: Optional[str]. Name of the module.
        tau: ArrayLike. Recovery time constant of the synaptic vesicles.
            Default ``200 ms``.
        U: ArrayLike. Fraction of resources used per action potential.
            Default 0.07.
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        in_size: Size,
        name: Optional[str] = None,
        # synapse parameters
        tau: ArrayLike = 200. * u.ms,
        U: ArrayLike = 0.07,
    ):
        super().__init__(name=name, in_size=in_size)

        # Parameters broadcast to the group shape.
        self.tau = init.param(tau, self.varshape)
        self.U = init.param(U, self.varshape)

    def init_state(self, batch_size: int = None, **kwargs):
        full = init.param(init.Constant(1.), self.varshape, batch_size)
        self.x = HiddenState(full)

    def reset_state(self, batch_size: int = None, **kwargs):
        full = init.param(init.Constant(1.), self.varshape, batch_size)
        self.x.value = full

    def update(self, pre_spike):
        x_prev = self.x.value

        def recovery(x):
            return (1 - x) / self.tau

        recovered = exp_euler_step(recovery, x_prev)
        # Spike-driven depletion, applied to the pre-step resources.
        self.x.value = recovered - pre_spike * self.U * x_prev
        return self.x.value
180
+
181
+
182
class AMPA(Synapse):
    r"""AMPA synapse model.

    **Model Descriptions**

    The AMPA receptor is an ionotropic receptor, i.e. an ion channel. When it
    is bound by neurotransmitters, it immediately opens the ion channel,
    changing the membrane potential of the postsynaptic neuron.

    A classical model describes the channel switch with a Markov process.
    In this model:

    - :math:`g` is the probability that the channel is open;
    - :math:`1-g` is the probability that it is closed;
    - :math:`\alpha` and :math:`\beta` are the transition rates.

    Because neurotransmitters open the channel, the transition from
    :math:`1-g` to :math:`g` depends on the neurotransmitter concentration
    :math:`[T]`. This gives the following Markov process.

    .. image:: ../../_static/synapse_markov.png
        :align: center

    As a differential equation:

    .. math::

        \frac{dg}{dt} = \alpha[T](1-g) - \beta g

    where:

    - :math:`\alpha [T]` is the transition rate from state :math:`(1-g)` to
      state :math:`(g)`, and :math:`\beta` is the rate of the reverse
      transition;
    - :math:`\alpha` is the binding constant and :math:`\beta` the unbinding
      constant;
    - :math:`[T]` is the neurotransmitter concentration. In :meth:`update` it
      equals ``T`` while less than ``T_dur`` has elapsed since the last
      presynaptic spike, and 0 otherwise.

    The post-synaptic current on the post-synaptic neuron is formulated as

    .. math::

        I_{syn} = g_{max} g (V-E)

    where :math:`g_{max}` is the maximum conductance and :math:`E` is the
    reversal potential. This module only integrates the dimensionless
    :math:`g` and returns it from :meth:`update`.

    .. [1] Vijayan S, Kopell N J. Thalamic model of awake alpha oscillations
           and implications for stimulus processing[J]. Proceedings of the
           National Academy of Sciences, 2012, 109(45): 18553-18558.

    Args:
        in_size: Size. Shape of the synapse group.
        name: Optional[str]. Name of the module.
        alpha: ArrayLike. Binding constant [ms^-1 mM^-1]. Default 0.98.
        beta: ArrayLike. Unbinding constant [ms^-1]. Default 0.18.
        T: ArrayLike. Transmitter concentration when the synapse is triggered
            by a pre-synaptic spike [mM]. Default 0.5.
        T_dur: ArrayLike. Duration of the transmitter pulse after a spike
            [ms]. Default 0.5.
        g_initializer: ArrayLike. Value or initializer for the open
            probability ``g``. Default ``init.ZeroInit()``.
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        in_size: Size,
        name: Optional[str] = None,
        alpha: ArrayLike = 0.98 / (u.ms * u.mM),
        beta: ArrayLike = 0.18 / u.ms,
        T: ArrayLike = 0.5 * u.mM,
        T_dur: ArrayLike = 0.5 * u.ms,
        g_initializer: ArrayLike = init.ZeroInit(),
    ):
        super().__init__(name=name, in_size=in_size)

        # parameters
        self.alpha = init.param(alpha, self.varshape)
        self.beta = init.param(beta, self.varshape)
        self.T = init.param(T, self.varshape)
        self.T_duration = init.param(T_dur, self.varshape)
        self.g_initializer = g_initializer

    def init_state(self, batch_size=None, **kwargs):
        # ``**kwargs`` is accepted for consistency with the other synapse
        # models' ``init_state`` signatures.
        self.g = HiddenState(init.param(self.g_initializer, self.varshape, batch_size))
        # Time of the most recent presynaptic spike. It starts far in the past
        # so that no transmitter pulse is active initially.
        self.spike_arrival_time = ShortTermState(init.param(init.Constant(-1e7 * u.ms), self.varshape, batch_size))

    def reset_state(self, batch_or_mode=None, **kwargs):
        self.g.value = init.param(self.g_initializer, self.varshape, batch_or_mode)
        self.spike_arrival_time.value = init.param(init.Constant(-1e7 * u.ms), self.varshape, batch_or_mode)

    def dg(self, g, t, TT):
        """Right-hand side ``alpha * [T] * (1 - g) - beta * g``."""
        return self.alpha * TT * (1 - g) - self.beta * g

    def update(self, pre_spike):
        t = environ.get('t')
        # Record the current time wherever a presynaptic spike arrived.
        self.spike_arrival_time.value = u.math.where(pre_spike, t, self.spike_arrival_time.value)
        # [T] is ``T`` while within ``T_duration`` of the last spike, else 0.
        TT = ((t - self.spike_arrival_time.value) < self.T_duration) * self.T
        self.g.value = exp_euler_step(self.dg, self.g.value, t, TT)
        return self.g.value
272
+
273
+
274
class GABAa(AMPA):
    r"""GABAa synapse model.

    **Model Descriptions**

    The GABAa synapse model uses the same kinetic equation as :class:`AMPA`,

    .. math::

        \frac{d g}{d t}&=\alpha[T](1-g) - \beta g \\
        I_{syn}&= - g_{max} g (V - E)

    but with different typical parameters:

    - The reversal potential :math:`E` is usually low, typically -80 mV.
    - Activating rate constant :math:`\alpha = 0.53\,\mathrm{ms^{-1}\,mM^{-1}}`.
    - De-activating rate constant :math:`\beta = 0.18\,\mathrm{ms^{-1}}`.
    - Transmitter concentration :math:`[T] = 1\,\mathrm{mM}` when the synapse
      is triggered by a pre-synaptic spike, lasting 1 ms.

    .. [1] Destexhe, Alain, and Denis Paré. "Impact of network activity
           on the integrative properties of neocortical pyramidal neurons
           in vivo." Journal of neurophysiology 81.4 (1999): 1531-1547.

    Args:
        in_size: Size. Shape of the synapse group.
        name: Optional[str]. Name of the module.
        alpha: ArrayLike. Binding constant [ms^-1 mM^-1]. Default 0.53.
        beta: ArrayLike. Unbinding constant [ms^-1]. Default 0.18.
        T: ArrayLike. Transmitter concentration when the synapse is triggered
            by a pre-synaptic spike [mM]. Default 1.
        T_dur: ArrayLike. Duration of the transmitter pulse after a spike
            [ms]. Default 1.
        g_initializer: ArrayLike. Value or initializer for the open
            probability ``g``, forwarded to :class:`AMPA`.
            Default ``init.ZeroInit()``.
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        in_size: Size,
        name: Optional[str] = None,
        alpha: ArrayLike = 0.53 / (u.ms * u.mM),
        beta: ArrayLike = 0.18 / u.ms,
        T: ArrayLike = 1.0 * u.mM,
        T_dur: ArrayLike = 1.0 * u.ms,
        g_initializer: ArrayLike = init.ZeroInit(),
    ):
        super().__init__(
            in_size=in_size,
            name=name,
            alpha=alpha,
            beta=beta,
            T=T,
            T_dur=T_dur,
            g_initializer=g_initializer,
        )
@@ -0,0 +1,132 @@
1
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ from __future__ import annotations
17
+
18
+ import unittest
19
+
20
+ import brainunit as u
21
+ import jax.numpy as jnp
22
+ import pytest
23
+
24
+ import brainstate as bst
25
+ from brainstate.nn import Expon, STP, STD
26
+
27
+
28
class TestSynapse(unittest.TestCase):
    """Smoke tests for the synapse models exported by ``brainstate.nn``."""

    def setUp(self):
        self.in_size = 10
        self.batch_size = 5
        self.time_steps = 100

    def generate_input(self):
        """Random ``(time_steps, batch_size, in_size)`` input carrying mS units."""
        return bst.random.randn(self.time_steps, self.batch_size, self.in_size) * u.mS

    def test_expon_synapse(self):
        tau = 20.0 * u.ms
        synapse = Expon(self.in_size, tau=tau)
        inputs = self.generate_input()

        # Test initialization
        self.assertEqual(synapse.in_size, (self.in_size,))
        self.assertEqual(synapse.out_size, (self.in_size,))
        self.assertEqual(synapse.tau, tau)

        # Test forward pass. ``init_state`` creates states in place and
        # returns nothing, so its result is not kept.
        synapse.init_state(self.batch_size)
        call = bst.compile.jit(synapse)
        with bst.environ.context(dt=0.1 * u.ms):
            for t in range(self.time_steps):
                out = call(inputs[t])
                self.assertEqual(out.shape, (self.batch_size, self.in_size))

            # Test exponential decay; still inside the context so ``dt`` is set.
            constant_input = jnp.ones((self.batch_size, self.in_size)) * u.mS
            out1 = call(constant_input)
            out2 = call(constant_input)
            self.assertTrue(jnp.all(out2 > out1))  # Output should increase with constant input

    @pytest.mark.skip(reason="Not implemented yet")
    def test_stp_synapse(self):
        tau_d = 200.0 * u.ms
        tau_f = 20.0 * u.ms
        U = 0.2
        synapse = STP(self.in_size, tau_d=tau_d, tau_f=tau_f, U=U)
        inputs = self.generate_input()

        # Test initialization
        self.assertEqual(synapse.in_size, (self.in_size,))
        self.assertEqual(synapse.out_size, (self.in_size,))
        self.assertEqual(synapse.tau_d, tau_d)
        self.assertEqual(synapse.tau_f, tau_f)
        self.assertEqual(synapse.U, U)

        # Test forward pass
        synapse.init_state(self.batch_size)
        call = bst.compile.jit(synapse)
        for t in range(self.time_steps):
            out = call(inputs[t])
            self.assertEqual(out.shape, (self.batch_size, self.in_size))

        # Test short-term plasticity
        constant_input = jnp.ones((self.batch_size, self.in_size)) * u.mS
        out1 = call(constant_input)
        out2 = call(constant_input)
        self.assertTrue(jnp.any(out2 != out1))  # Output should change due to STP

    @pytest.mark.skip(reason="Not implemented yet")
    def test_std_synapse(self):
        tau = 200.0
        U = 0.2
        synapse = STD(self.in_size, tau=tau, U=U)
        inputs = self.generate_input()

        # Test initialization
        self.assertEqual(synapse.in_size, (self.in_size,))
        self.assertEqual(synapse.out_size, (self.in_size,))
        self.assertEqual(synapse.tau, tau)
        self.assertEqual(synapse.U, U)

        # Test forward pass
        synapse.init_state(self.batch_size)
        for t in range(self.time_steps):
            out = synapse(inputs[t])
            self.assertEqual(out.shape, (self.batch_size, self.in_size))

        # Test short-term depression
        constant_input = jnp.ones((self.batch_size, self.in_size))
        out1 = synapse(constant_input)
        out2 = synapse(constant_input)
        self.assertTrue(jnp.all(out2 < out1))  # Output should decrease due to STD

    def test_keep_size(self):
        in_size = (2, 3)
        for SynapseClass in [Expon]:
            synapse = SynapseClass(in_size)
            self.assertEqual(synapse.in_size, in_size)
            self.assertEqual(synapse.out_size, in_size)

            inputs = bst.random.randn(self.time_steps, self.batch_size, *in_size) * u.mS
            synapse.init_state(self.batch_size)
            call = bst.compile.jit(synapse)
            with bst.environ.context(dt=0.1 * u.ms):
                for t in range(self.time_steps):
                    out = call(inputs[t])
                    self.assertEqual(out.shape, (self.batch_size, *in_size))
128
+
129
+
130
if __name__ == '__main__':
    # Use a unit-carrying default time step. This is consistent with the
    # ms-based time constants of the synapse models and with the
    # ``dt=0.1 * u.ms`` contexts used inside the tests.
    with bst.environ.context(dt=0.1 * u.ms):
        unittest.main()
@@ -0,0 +1,154 @@
1
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ from __future__ import annotations
16
+
17
+ from typing import Union, Optional, Sequence, Callable
18
+
19
+ import brainunit as u
20
+
21
+ from brainstate import environ, init, random
22
+ from brainstate._state import ShortTermState
23
+ from brainstate.compile import while_loop
24
+ from brainstate.nn._dynamics._dynamics_base import Dynamics
25
+ from brainstate.typing import ArrayLike, Size, DTypeLike
26
+
27
# Input generators exported by this module.
__all__ = ['SpikeTime', 'PoissonSpike', 'PoissonEncoder']
32
+
33
+
34
class SpikeTime(Dynamics):
    """The input neuron group characterized by spikes emitting at given times.

    >>> # Get 2 neurons; neuron 0 fires at 10 ms and neuron 1 fires at 20 ms.
    >>> SpikeTime(2, indices=[0, 1], times=[10, 20])
    >>> # or
    >>> # Get 2 neurons, the neuron 0 fires spikes at 10 ms and 20 ms.
    >>> SpikeTime(2, times=[10, 20], indices=[0, 0])
    >>> # or
    >>> # Get 2 neurons, neuron 0 fires at 10 ms and 30 ms, neuron 1 fires at 20 ms.
    >>> SpikeTime(2, times=[10, 20, 30], indices=[0, 1, 0])
    >>> # or
    >>> # Get 2 neurons; at 10 ms, neuron 0 fires; at 20 ms, neuron 0 and 1 fire;
    >>> # at 30 ms, neuron 1 fires.
    >>> SpikeTime(2, times=[10, 20, 20, 30], indices=[0, 0, 1, 1])

    Parameters
    ----------
    in_size : int, tuple, list
        The neuron group geometry.
    indices : list, tuple, ArrayType
        The neuron indices at each time point to emit spikes.
    times : list, tuple, ArrayType
        The time points which generate the spikes. Must have the same length
        as ``indices``.
    spk_type : DTypeLike
        dtype of the returned spike array. Default ``bool``.
    name : str, optional
        The name of the dynamic system.
    need_sort : bool
        If True (default), ``times`` and ``indices`` are sorted by time.
        :meth:`update` walks the events in stored order, so pass False only
        when ``times`` is already sorted.
    """

    def __init__(
        self,
        in_size: Size,
        indices: Union[Sequence, ArrayLike],
        times: Union[Sequence, ArrayLike],
        spk_type: DTypeLike = bool,
        name: Optional[str] = None,
        need_sort: bool = True,
    ):
        super().__init__(in_size=in_size, name=name)

        # parameters
        if len(indices) != len(times):
            raise ValueError(f'The length of "indices" and "times" must be the same. '
                             f'However, we got {len(indices)} != {len(times)}.')
        self.num_times = len(times)
        self.spk_type = spk_type

        # data about times and indices
        self.times = u.math.asarray(times)
        self.indices = u.math.asarray(indices, dtype=environ.ditype())
        if need_sort:
            sort_idx = u.math.argsort(self.times)
            self.indices = self.indices[sort_idx]
            self.times = self.times[sort_idx]

    def init_state(self, *args, **kwargs):
        # Position of the next scheduled (time, index) event that has not
        # fired yet. Starting at -1 made the first check read ``times[-1]``
        # (the LAST event), so nothing fired until the final time.
        self.i = ShortTermState(0)

    def reset_state(self, batch_size=None, **kwargs):
        self.i.value = 0

    def update(self):
        spike = u.math.zeros(self.varshape, dtype=self.spk_type)
        if self.num_times == 0:
            # Nothing is scheduled; this also avoids indexing an empty ``times``.
            return spike

        t = environ.get('t')

        def _cond_fun(spikes):
            i = self.i.value
            # Fire every pending event whose time has been reached. When
            # i == num_times, the first operand is False. The out-of-range
            # ``times[i]`` read is clamped by JAX, so it does not error.
            return u.math.logical_and(i < self.num_times, t >= self.times[i])

        def _body_fun(spikes):
            i = self.i.value
            spikes = spikes.at[..., self.indices[i]].set(True)
            self.i.value += 1
            return spikes

        spike = while_loop(_cond_fun, _body_fun, spike)
        return spike
110
+
111
+
112
class PoissonSpike(Dynamics):
    """Poisson spike generator with firing rates fixed at construction.

    On every call to :meth:`update`, each neuron draws a uniform random
    number with ``brainstate.random.rand``. It emits a spike when that number
    is ``<= freqs * dt``, i.e. with a per-step probability of roughly
    ``freqs * dt``.

    Parameters
    ----------
    in_size : int, tuple, list
        The neuron group geometry.
    freqs : ArrayLike, Callable
        Firing rates, broadcast or initialized to the group shape with
        ``init.param``.
    spk_type : DTypeLike
        dtype of the returned spike array. Default ``bool``.
    name : str, optional
        The name of the dynamic system.
    """

    def __init__(
        self,
        in_size: Size,
        freqs: Union[ArrayLike, Callable],
        spk_type: DTypeLike = bool,
        name: Optional[str] = None,
    ):
        super().__init__(in_size=in_size, name=name)

        self.spk_type = spk_type

        # parameters
        self.freqs = init.param(freqs, self.varshape, allow_none=False)

    def update(self):
        # ``rand`` follows the ``rand(d0, d1, ...)`` convention, so the shape
        # must be unpacked (as in PoissonEncoder); passing the tuple itself
        # is not a valid shape.
        spikes = random.rand(*self.varshape) <= (self.freqs * environ.get_dt())
        spikes = u.math.asarray(spikes, dtype=self.spk_type)
        return spikes
135
+
136
+
137
class PoissonEncoder(Dynamics):
    """Poisson encoder that turns per-call firing rates into spikes.

    The rates are not stored on the module; they are passed to
    :meth:`update` on every step. Each neuron spikes when a uniform draw from
    ``brainstate.random.rand`` is ``<= freqs * dt``.

    Parameters
    ----------
    in_size : int, tuple, list
        The neuron group geometry.
    spk_type : DTypeLike
        dtype of the returned spike array. Default ``bool``.
    name : str, optional
        The name of the dynamic system.
    """

    def __init__(
        self,
        in_size: Size,
        spk_type: DTypeLike = bool,
        name: Optional[str] = None,
    ):
        super().__init__(in_size=in_size, name=name)
        self.spk_type = spk_type

    def update(self, freqs: ArrayLike):
        draws = random.rand(*self.varshape)
        threshold = freqs * environ.get_dt()
        return u.math.asarray(draws <= threshold, dtype=self.spk_type)
@@ -13,20 +13,13 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- """
16
+ from __future__ import annotations
17
17
 
18
- This module defines the basic classes for synaptic projections.
18
+ from brainstate.nn._dynamics._dynamics_base import Projection
19
19
 
20
- """
20
# Nothing is exported yet; ExponentialSynapse below is only a placeholder.
__all__ = []
21
22
 
22
- from ._align_post import *
23
- from ._align_post import __all__ as align_post_all
24
- from ._align_pre import *
25
- from ._align_pre import __all__ as align_pre_all
26
- from ._delta import *
27
- from ._delta import __all__ as delta_all
28
- from ._vanilla import *
29
- from ._vanilla import __all__ as vanilla_all
30
23
 
31
- __all__ = align_post_all + align_pre_all + delta_all + vanilla_all
32
- del align_post_all, align_pre_all, delta_all, vanilla_all
24
class ExponentialSynapse(Projection):
    """Placeholder projection type; it currently adds nothing to :class:`Projection`."""