brainstate 0.1.9__py2.py3-none-any.whl → 0.2.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (163)
  1. brainstate/__init__.py +130 -19
  2. brainstate/_compatible_import.py +201 -9
  3. brainstate/_compatible_import_test.py +681 -0
  4. brainstate/_deprecation.py +210 -0
  5. brainstate/_deprecation_test.py +2319 -0
  6. brainstate/{util/error.py → _error.py} +10 -20
  7. brainstate/_state.py +94 -47
  8. brainstate/_state_test.py +1 -1
  9. brainstate/_utils.py +1 -1
  10. brainstate/environ.py +1279 -347
  11. brainstate/environ_test.py +1187 -26
  12. brainstate/graph/__init__.py +6 -13
  13. brainstate/graph/_node.py +240 -0
  14. brainstate/graph/_node_test.py +589 -0
  15. brainstate/graph/{_graph_operation.py → _operation.py} +632 -746
  16. brainstate/graph/_operation_test.py +1147 -0
  17. brainstate/mixin.py +1209 -141
  18. brainstate/mixin_test.py +991 -51
  19. brainstate/nn/__init__.py +74 -72
  20. brainstate/nn/_activations.py +587 -295
  21. brainstate/nn/_activations_test.py +109 -86
  22. brainstate/nn/_collective_ops.py +393 -274
  23. brainstate/nn/_collective_ops_test.py +746 -15
  24. brainstate/nn/_common.py +114 -66
  25. brainstate/nn/_common_test.py +154 -0
  26. brainstate/nn/_conv.py +1652 -143
  27. brainstate/nn/_conv_test.py +838 -227
  28. brainstate/nn/_delay.py +95 -29
  29. brainstate/nn/_delay_test.py +25 -20
  30. brainstate/nn/_dropout.py +359 -167
  31. brainstate/nn/_dropout_test.py +429 -52
  32. brainstate/nn/_dynamics.py +14 -90
  33. brainstate/nn/_dynamics_test.py +1 -12
  34. brainstate/nn/_elementwise.py +492 -313
  35. brainstate/nn/_elementwise_test.py +806 -145
  36. brainstate/nn/_embedding.py +369 -19
  37. brainstate/nn/_embedding_test.py +156 -0
  38. brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +10 -16
  39. brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +6 -5
  40. brainstate/nn/{_linear_mv.py → _event_linear.py} +2 -2
  41. brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +6 -5
  42. brainstate/nn/_exp_euler.py +200 -38
  43. brainstate/nn/_exp_euler_test.py +350 -8
  44. brainstate/nn/_linear.py +391 -71
  45. brainstate/nn/_linear_test.py +427 -59
  46. brainstate/nn/_metrics.py +1070 -0
  47. brainstate/nn/_metrics_test.py +611 -0
  48. brainstate/nn/_module.py +10 -3
  49. brainstate/nn/_module_test.py +1 -1
  50. brainstate/nn/_normalizations.py +688 -329
  51. brainstate/nn/_normalizations_test.py +663 -37
  52. brainstate/nn/_paddings.py +1020 -0
  53. brainstate/nn/_paddings_test.py +723 -0
  54. brainstate/nn/_poolings.py +1404 -342
  55. brainstate/nn/_poolings_test.py +828 -92
  56. brainstate/nn/{_rate_rnns.py → _rnns.py} +446 -54
  57. brainstate/nn/_rnns_test.py +593 -0
  58. brainstate/nn/_utils.py +132 -5
  59. brainstate/nn/_utils_test.py +402 -0
  60. brainstate/{init/_random_inits.py → nn/init.py} +301 -45
  61. brainstate/{init/_random_inits_test.py → nn/init_test.py} +51 -20
  62. brainstate/random/__init__.py +247 -1
  63. brainstate/random/_rand_funs.py +668 -346
  64. brainstate/random/_rand_funs_test.py +74 -1
  65. brainstate/random/_rand_seed.py +541 -76
  66. brainstate/random/_rand_seed_test.py +1 -1
  67. brainstate/random/_rand_state.py +601 -393
  68. brainstate/random/_rand_state_test.py +551 -0
  69. brainstate/transform/__init__.py +59 -0
  70. brainstate/transform/_ad_checkpoint.py +176 -0
  71. brainstate/{compile → transform}/_ad_checkpoint_test.py +1 -1
  72. brainstate/{augment → transform}/_autograd.py +360 -113
  73. brainstate/{augment → transform}/_autograd_test.py +2 -2
  74. brainstate/transform/_conditions.py +316 -0
  75. brainstate/{compile → transform}/_conditions_test.py +11 -11
  76. brainstate/{compile → transform}/_error_if.py +22 -20
  77. brainstate/{compile → transform}/_error_if_test.py +1 -1
  78. brainstate/transform/_eval_shape.py +145 -0
  79. brainstate/{augment → transform}/_eval_shape_test.py +1 -1
  80. brainstate/{compile → transform}/_jit.py +99 -46
  81. brainstate/{compile → transform}/_jit_test.py +3 -3
  82. brainstate/{compile → transform}/_loop_collect_return.py +219 -80
  83. brainstate/{compile → transform}/_loop_collect_return_test.py +1 -1
  84. brainstate/{compile → transform}/_loop_no_collection.py +133 -34
  85. brainstate/{compile → transform}/_loop_no_collection_test.py +2 -2
  86. brainstate/transform/_make_jaxpr.py +2016 -0
  87. brainstate/transform/_make_jaxpr_test.py +1510 -0
  88. brainstate/transform/_mapping.py +529 -0
  89. brainstate/transform/_mapping_test.py +194 -0
  90. brainstate/{compile → transform}/_progress_bar.py +78 -25
  91. brainstate/{augment → transform}/_random.py +65 -45
  92. brainstate/{compile → transform}/_unvmap.py +102 -5
  93. brainstate/transform/_util.py +286 -0
  94. brainstate/typing.py +594 -61
  95. brainstate/typing_test.py +780 -0
  96. brainstate/util/__init__.py +9 -32
  97. brainstate/util/_others.py +1025 -0
  98. brainstate/util/_others_test.py +962 -0
  99. brainstate/util/_pretty_pytree.py +1301 -0
  100. brainstate/util/_pretty_pytree_test.py +675 -0
  101. brainstate/util/{pretty_repr.py → _pretty_repr.py} +161 -27
  102. brainstate/util/_pretty_repr_test.py +696 -0
  103. brainstate/util/filter.py +557 -81
  104. brainstate/util/filter_test.py +912 -0
  105. brainstate/util/struct.py +769 -382
  106. brainstate/util/struct_test.py +602 -0
  107. {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/METADATA +34 -17
  108. brainstate-0.2.0.dist-info/RECORD +111 -0
  109. brainstate/augment/__init__.py +0 -30
  110. brainstate/augment/_eval_shape.py +0 -99
  111. brainstate/augment/_mapping.py +0 -1060
  112. brainstate/augment/_mapping_test.py +0 -597
  113. brainstate/compile/__init__.py +0 -38
  114. brainstate/compile/_ad_checkpoint.py +0 -204
  115. brainstate/compile/_conditions.py +0 -256
  116. brainstate/compile/_make_jaxpr.py +0 -888
  117. brainstate/compile/_make_jaxpr_test.py +0 -156
  118. brainstate/compile/_util.py +0 -147
  119. brainstate/functional/__init__.py +0 -27
  120. brainstate/graph/_graph_node.py +0 -244
  121. brainstate/graph/_graph_node_test.py +0 -73
  122. brainstate/graph/_graph_operation_test.py +0 -563
  123. brainstate/init/__init__.py +0 -26
  124. brainstate/init/_base.py +0 -52
  125. brainstate/init/_generic.py +0 -244
  126. brainstate/init/_regular_inits.py +0 -105
  127. brainstate/init/_regular_inits_test.py +0 -50
  128. brainstate/nn/_inputs.py +0 -608
  129. brainstate/nn/_ltp.py +0 -28
  130. brainstate/nn/_neuron.py +0 -705
  131. brainstate/nn/_neuron_test.py +0 -161
  132. brainstate/nn/_others.py +0 -46
  133. brainstate/nn/_projection.py +0 -486
  134. brainstate/nn/_rate_rnns_test.py +0 -63
  135. brainstate/nn/_readout.py +0 -209
  136. brainstate/nn/_readout_test.py +0 -53
  137. brainstate/nn/_stp.py +0 -236
  138. brainstate/nn/_synapse.py +0 -505
  139. brainstate/nn/_synapse_test.py +0 -131
  140. brainstate/nn/_synaptic_projection.py +0 -423
  141. brainstate/nn/_synouts.py +0 -162
  142. brainstate/nn/_synouts_test.py +0 -57
  143. brainstate/nn/metrics.py +0 -388
  144. brainstate/optim/__init__.py +0 -38
  145. brainstate/optim/_base.py +0 -64
  146. brainstate/optim/_lr_scheduler.py +0 -448
  147. brainstate/optim/_lr_scheduler_test.py +0 -50
  148. brainstate/optim/_optax_optimizer.py +0 -152
  149. brainstate/optim/_optax_optimizer_test.py +0 -53
  150. brainstate/optim/_sgd_optimizer.py +0 -1104
  151. brainstate/random/_random_for_unit.py +0 -52
  152. brainstate/surrogate.py +0 -1957
  153. brainstate/transform.py +0 -23
  154. brainstate/util/caller.py +0 -98
  155. brainstate/util/others.py +0 -540
  156. brainstate/util/pretty_pytree.py +0 -945
  157. brainstate/util/pretty_pytree_test.py +0 -159
  158. brainstate/util/pretty_table.py +0 -2954
  159. brainstate/util/scaling.py +0 -258
  160. brainstate-0.1.9.dist-info/RECORD +0 -130
  161. {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/WHEEL +0 -0
  162. {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/licenses/LICENSE +0 -0
  163. {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/top_level.txt +0 -0
brainstate/nn/_synapse.py DELETED
@@ -1,505 +0,0 @@
1
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
- # -*- coding: utf-8 -*-
17
-
18
-
19
- from typing import Optional, Callable
20
-
21
- import brainunit as u
22
-
23
- from brainstate import init, environ
24
- from brainstate._state import ShortTermState, HiddenState
25
- from brainstate.mixin import AlignPost
26
- from brainstate.typing import ArrayLike, Size
27
- from ._dynamics import Dynamics
28
- from ._exp_euler import exp_euler_step
29
-
30
- __all__ = [
31
- 'Synapse', 'Expon', 'DualExpon', 'Alpha', 'AMPA', 'GABAa',
32
- ]
33
-
34
-
35
- class Synapse(Dynamics):
36
- """
37
- Base class for synapse dynamics.
38
-
39
- This class serves as the foundation for all synapse models in the library,
40
- providing a common interface for implementing various types of synaptic
41
- connectivity and transmission mechanisms.
42
-
43
- Synapses are responsible for modeling the transmission of signals between
44
- neurons, including temporal dynamics, plasticity, and neurotransmitter effects.
45
- All specific synapse implementations (like Expon, Alpha, AMPA, etc.) should
46
- inherit from this class.
47
-
48
- See Also
49
- --------
50
- Expon : Simple first-order exponential decay synapse model
51
- Alpha : Alpha function synapse model
52
- DualExpon : Dual exponential synapse model
53
- STP : Synapse with short-term plasticity
54
- STD : Synapse with short-term depression
55
- AMPA : AMPA receptor synapse model
56
- GABAa : GABAa receptor synapse model
57
- """
58
- __module__ = 'brainstate.nn'
59
-
60
-
61
- class Expon(Synapse, AlignPost):
62
- r"""
63
- Exponential decay synapse model.
64
-
65
- This class implements a simple first-order exponential decay synapse model where
66
- the synaptic conductance g decays exponentially with time constant tau:
67
-
68
- $$
69
- dg/dt = -g/\tau + \text{input}
70
- $$
71
-
72
- The model is widely used for basic synaptic transmission modeling.
73
-
74
- Parameters
75
- ----------
76
- in_size : Size
77
- Size of the input.
78
- name : str, optional
79
- Name of the synapse instance.
80
- tau : ArrayLike, default=8.0*u.ms
81
- Time constant of decay in milliseconds.
82
- g_initializer : ArrayLike or Callable, default=init.ZeroInit(unit=u.mS)
83
- Initial value or initializer for synaptic conductance.
84
-
85
- Attributes
86
- ----------
87
- g : HiddenState
88
- Synaptic conductance state variable.
89
- tau : Parameter
90
- Time constant of decay.
91
-
92
- Notes
93
- -----
94
- The implementation uses an exponential Euler integration method.
95
- The output of this synapse is the conductance value.
96
-
97
- This class inherits from :py:class:`AlignPost`, which means it can be used in projection patterns
98
- where synaptic variables are aligned with post-synaptic neurons, enabling event-driven
99
- computation and more efficient handling of sparse connectivity patterns.
100
- """
101
- __module__ = 'brainstate.nn'
102
-
103
- def __init__(
104
- self,
105
- in_size: Size,
106
- name: Optional[str] = None,
107
- tau: ArrayLike = 8.0 * u.ms,
108
- g_initializer: ArrayLike | Callable = init.ZeroInit(unit=u.mS),
109
- ):
110
- super().__init__(name=name, in_size=in_size)
111
-
112
- # parameters
113
- self.tau = init.param(tau, self.varshape)
114
- self.g_initializer = g_initializer
115
-
116
- def init_state(self, batch_size: int = None, **kwargs):
117
- self.g = HiddenState(init.param(self.g_initializer, self.varshape, batch_size))
118
-
119
- def reset_state(self, batch_size: int = None, **kwargs):
120
- self.g.value = init.param(self.g_initializer, self.varshape, batch_size)
121
-
122
- def update(self, x=None):
123
- g = exp_euler_step(lambda g: self.sum_current_inputs(-g) / self.tau, self.g.value)
124
- self.g.value = self.sum_delta_inputs(g)
125
- if x is not None: self.g.value += x
126
- return self.g.value
127
-
128
-
129
- class DualExpon(Synapse, AlignPost):
130
- r"""
131
- Dual exponential synapse model.
132
-
133
- This class implements a synapse model with separate rise and decay time constants,
134
- which produces a more biologically realistic conductance waveform than a single
135
- exponential model. The model is characterized by the differential equation system:
136
-
137
- dg_rise/dt = -g_rise/tau_rise
138
- dg_decay/dt = -g_decay/tau_decay
139
- g = a * (g_decay - g_rise)
140
-
141
- where $a$ is a normalization factor that ensures the peak conductance reaches
142
- the desired amplitude.
143
-
144
- Parameters
145
- ----------
146
- in_size : Size
147
- Size of the input.
148
- name : str, optional
149
- Name of the synapse instance.
150
- tau_decay : ArrayLike, default=10.0*u.ms
151
- Time constant of decay in milliseconds.
152
- tau_rise : ArrayLike, default=1.0*u.ms
153
- Time constant of rise in milliseconds.
154
- A : ArrayLike, optional
155
- Amplitude scaling factor. If None, a scaling factor is automatically
156
- calculated to normalize the peak amplitude.
157
- g_initializer : ArrayLike or Callable, default=init.ZeroInit(unit=u.mS)
158
- Initial value or initializer for synaptic conductance.
159
-
160
- Attributes
161
- ----------
162
- g_rise : HiddenState
163
- Rise component of synaptic conductance.
164
- g_decay : HiddenState
165
- Decay component of synaptic conductance.
166
- tau_rise : Parameter
167
- Time constant of rise phase.
168
- tau_decay : Parameter
169
- Time constant of decay phase.
170
- a : Parameter
171
- Normalization factor calculated from tau_rise, tau_decay, and A.
172
-
173
- Notes
174
- -----
175
- The dual exponential model produces a conductance waveform that is more
176
- physiologically realistic than a simple exponential decay, with a finite
177
- rise time followed by a slower decay.
178
-
179
- The implementation uses an exponential Euler integration method.
180
- The output of this synapse is the normalized difference between decay and rise components.
181
-
182
- This class inherits from :py:class:`AlignPost`, which means it can be used in projection patterns
183
- where synaptic variables are aligned with post-synaptic neurons, enabling event-driven
184
- computation and more efficient handling of sparse connectivity patterns.
185
- """
186
- __module__ = 'brainstate.nn'
187
-
188
- def __init__(
189
- self,
190
- in_size: Size,
191
- name: Optional[str] = None,
192
- tau_decay: ArrayLike = 10.0 * u.ms,
193
- tau_rise: ArrayLike = 1.0 * u.ms,
194
- A: Optional[ArrayLike] = None,
195
- g_initializer: ArrayLike | Callable = init.ZeroInit(unit=u.mS),
196
- ):
197
- super().__init__(name=name, in_size=in_size)
198
-
199
- # parameters
200
- self.tau_decay = init.param(tau_decay, self.varshape)
201
- self.tau_rise = init.param(tau_rise, self.varshape)
202
- A = self._format_dual_exp_A(A)
203
- self.a = (self.tau_decay - self.tau_rise) / self.tau_rise / self.tau_decay * A
204
- self.g_initializer = g_initializer
205
-
206
- def _format_dual_exp_A(self, A):
207
- A = init.param(A, sizes=self.varshape, allow_none=True)
208
- if A is None:
209
- A = (
210
- self.tau_decay / (self.tau_decay - self.tau_rise) *
211
- u.math.float_power(self.tau_rise / self.tau_decay,
212
- self.tau_rise / (self.tau_rise - self.tau_decay))
213
- )
214
- return A
215
-
216
- def init_state(self, batch_size: int = None, **kwargs):
217
- self.g_rise = HiddenState(init.param(self.g_initializer, self.varshape, batch_size))
218
- self.g_decay = HiddenState(init.param(self.g_initializer, self.varshape, batch_size))
219
-
220
- def reset_state(self, batch_size: int = None, **kwargs):
221
- self.g_rise.value = init.param(self.g_initializer, self.varshape, batch_size)
222
- self.g_decay.value = init.param(self.g_initializer, self.varshape, batch_size)
223
-
224
- def update(self, x=None):
225
- g_rise = exp_euler_step(lambda h: -h / self.tau_rise, self.g_rise.value)
226
- g_decay = exp_euler_step(lambda g: -g / self.tau_decay, self.g_decay.value)
227
- self.g_rise.value = self.sum_delta_inputs(g_rise)
228
- self.g_decay.value = self.sum_delta_inputs(g_decay)
229
- if x is not None:
230
- self.g_rise.value += x
231
- self.g_decay.value += x
232
- return self.a * (self.g_decay.value - self.g_rise.value)
233
-
234
-
235
- class Alpha(Synapse):
236
- r"""
237
- Alpha synapse model.
238
-
239
- This class implements the alpha function synapse model, which produces
240
- a smooth, biologically realistic synaptic conductance waveform.
241
- The model is characterized by the differential equation system:
242
-
243
- dh/dt = -h/tau
244
- dg/dt = -g/tau + h/tau
245
-
246
- This produces a response that rises and then falls with a characteristic
247
- time constant $\tau$, with peak amplitude occurring at time $t = \tau$.
248
-
249
- Parameters
250
- ----------
251
- in_size : Size
252
- Size of the input.
253
- name : str, optional
254
- Name of the synapse instance.
255
- tau : ArrayLike, default=8.0*u.ms
256
- Time constant of the alpha function in milliseconds.
257
- g_initializer : ArrayLike or Callable, default=init.ZeroInit(unit=u.mS)
258
- Initial value or initializer for synaptic conductance.
259
-
260
- Attributes
261
- ----------
262
- g : HiddenState
263
- Synaptic conductance state variable.
264
- h : HiddenState
265
- Auxiliary state variable for implementing the alpha function.
266
- tau : Parameter
267
- Time constant of the alpha function.
268
-
269
- Notes
270
- -----
271
- The alpha function is defined as g(t) = (t/tau) * exp(1-t/tau) for t ≥ 0.
272
- This implementation uses an exponential Euler integration method.
273
- The output of this synapse is the conductance value.
274
- """
275
- __module__ = 'brainstate.nn'
276
-
277
- def __init__(
278
- self,
279
- in_size: Size,
280
- name: Optional[str] = None,
281
- tau: ArrayLike = 8.0 * u.ms,
282
- g_initializer: ArrayLike | Callable = init.ZeroInit(unit=u.mS),
283
- ):
284
- super().__init__(name=name, in_size=in_size)
285
-
286
- # parameters
287
- self.tau = init.param(tau, self.varshape)
288
- self.g_initializer = g_initializer
289
-
290
- def init_state(self, batch_size: int = None, **kwargs):
291
- self.g = HiddenState(init.param(self.g_initializer, self.varshape, batch_size))
292
- self.h = HiddenState(init.param(self.g_initializer, self.varshape, batch_size))
293
-
294
- def reset_state(self, batch_size: int = None, **kwargs):
295
- self.g.value = init.param(self.g_initializer, self.varshape, batch_size)
296
- self.h.value = init.param(self.g_initializer, self.varshape, batch_size)
297
-
298
- def update(self, x=None):
299
- h = exp_euler_step(lambda h: -h / self.tau, self.h.value)
300
- self.g.value = exp_euler_step(lambda g, h: -g / self.tau + h / self.tau, self.g.value, self.h.value)
301
- self.h.value = self.sum_delta_inputs(h)
302
- if x is not None:
303
- self.h.value += x
304
- return self.g.value
305
-
306
-
307
- class AMPA(Synapse):
308
- r"""AMPA receptor synapse model.
309
-
310
- This class implements a kinetic model of AMPA (α-amino-3-hydroxy-5-methyl-4-isoxazolepropionic acid)
311
- receptor-mediated synaptic transmission. AMPA receptors are ionotropic glutamate receptors that mediate
312
- fast excitatory synaptic transmission in the central nervous system.
313
-
314
- The model uses a Markov process approach to describe the state transitions of AMPA receptors
315
- between closed and open states, governed by neurotransmitter binding:
316
-
317
- $$
318
- \frac{dg}{dt} = \alpha [T] (1-g) - \beta g
319
- $$
320
-
321
- $$
322
- I_{syn} = g_{max} \cdot g \cdot (V - E)
323
- $$
324
-
325
- where:
326
- - $g$ represents the fraction of receptors in the open state
327
- - $\alpha$ is the binding rate constant [ms^-1 mM^-1]
328
- - $\beta$ is the unbinding rate constant [ms^-1]
329
- - $[T]$ is the neurotransmitter concentration [mM]
330
- - $I_{syn}$ is the resulting synaptic current
331
- - $g_{max}$ is the maximum conductance
332
- - $V$ is the membrane potential
333
- - $E$ is the reversal potential
334
-
335
- The neurotransmitter concentration $[T]$ follows a square pulse of amplitude T and
336
- duration T_dur after each presynaptic spike.
337
-
338
- Parameters
339
- ----------
340
- in_size : Size
341
- Size of the input.
342
- name : str, optional
343
- Name of the synapse instance.
344
- alpha : ArrayLike, default=0.98/(u.ms*u.mM)
345
- Binding rate constant [ms^-1 mM^-1].
346
- beta : ArrayLike, default=0.18/u.ms
347
- Unbinding rate constant [ms^-1].
348
- T : ArrayLike, default=0.5*u.mM
349
- Peak neurotransmitter concentration when released [mM].
350
- T_dur : ArrayLike, default=0.5*u.ms
351
- Duration of neurotransmitter presence in the synaptic cleft [ms].
352
- g_initializer : ArrayLike or Callable, default=init.ZeroInit()
353
- Initial value or initializer for the synaptic conductance.
354
-
355
- Attributes
356
- ----------
357
- g : HiddenState
358
- Fraction of receptors in the open state.
359
- spike_arrival_time : ShortTermState
360
- Time of the most recent presynaptic spike.
361
-
362
- Notes
363
- -----
364
- - The model captures the fast-rising and relatively fast-decaying excitatory currents
365
- characteristic of AMPA receptor-mediated transmission.
366
- - The time course of the synaptic conductance is determined by both the binding and
367
- unbinding rate constants and the duration of transmitter presence.
368
- - This implementation uses an exponential Euler integration method.
369
-
370
- References
371
- ----------
372
- .. [1] Destexhe, A., Mainen, Z. F., & Sejnowski, T. J. (1994). Synthesis of models for
373
- excitable membranes, synaptic transmission and neuromodulation using a common
374
- kinetic formalism. Journal of computational neuroscience, 1(3), 195-230.
375
- .. [2] Vijayan, S., & Kopell, N. J. (2012). Thalamic model of awake alpha oscillations
376
- and implications for stimulus processing. Proceedings of the National Academy
377
- of Sciences, 109(45), 18553-18558.
378
- """
379
-
380
- def __init__(
381
- self,
382
- in_size: Size,
383
- name: Optional[str] = None,
384
- alpha: ArrayLike = 0.98 / (u.ms * u.mM),
385
- beta: ArrayLike = 0.18 / u.ms,
386
- T: ArrayLike = 0.5 * u.mM,
387
- T_dur: ArrayLike = 0.5 * u.ms,
388
- g_initializer: ArrayLike | Callable = init.ZeroInit(unit=u.mS),
389
- ):
390
- super().__init__(name=name, in_size=in_size)
391
-
392
- # parameters
393
- self.alpha = init.param(alpha, self.varshape)
394
- self.beta = init.param(beta, self.varshape)
395
- self.T = init.param(T, self.varshape)
396
- self.T_duration = init.param(T_dur, self.varshape)
397
- self.g_initializer = g_initializer
398
-
399
- def init_state(self, batch_size=None):
400
- self.g = HiddenState(init.param(self.g_initializer, self.varshape, batch_size))
401
- self.spike_arrival_time = ShortTermState(init.param(init.Constant(-1e7 * u.ms), self.varshape, batch_size))
402
-
403
- def reset_state(self, batch_or_mode=None, **kwargs):
404
- self.g.value = init.param(self.g_initializer, self.varshape, batch_or_mode)
405
- self.spike_arrival_time.value = init.param(init.Constant(-1e7 * u.ms), self.varshape, batch_or_mode)
406
-
407
- def update(self, pre_spike):
408
- t = environ.get('t')
409
- self.spike_arrival_time.value = u.math.where(pre_spike, t, self.spike_arrival_time.value)
410
- TT = ((t - self.spike_arrival_time.value) < self.T_duration) * self.T
411
- dg = lambda g: self.alpha * TT * (1 * u.get_unit(g) - g) - self.beta * g
412
- self.g.value = exp_euler_step(dg, self.g.value)
413
- return self.g.value
414
-
415
-
416
- class GABAa(AMPA):
417
- r"""GABAa receptor synapse model.
418
-
419
- This class implements a kinetic model of GABAa (gamma-aminobutyric acid type A)
420
- receptor-mediated synaptic transmission. GABAa receptors are ionotropic chloride channels
421
- that mediate fast inhibitory synaptic transmission in the central nervous system.
422
-
423
- The model uses the same Markov process approach as the AMPA model but with different
424
- kinetic parameters appropriate for GABAa receptors:
425
-
426
- $$
427
- \frac{dg}{dt} = \alpha [T] (1-g) - \beta g
428
- $$
429
-
430
- $$
431
- I_{syn} = - g_{max} \cdot g \cdot (V - E)
432
- $$
433
-
434
- where:
435
- - $g$ represents the fraction of receptors in the open state
436
- - $\alpha$ is the binding rate constant [ms^-1 mM^-1], typically slower than AMPA
437
- - $\beta$ is the unbinding rate constant [ms^-1]
438
- - $[T]$ is the neurotransmitter (GABA) concentration [mM]
439
- - $I_{syn}$ is the resulting synaptic current (note the negative sign indicating inhibition)
440
- - $g_{max}$ is the maximum conductance
441
- - $V$ is the membrane potential
442
- - $E$ is the reversal potential (typically around -80 mV for chloride)
443
-
444
- The neurotransmitter concentration $[T]$ follows a square pulse of amplitude T and
445
- duration T_dur after each presynaptic spike.
446
-
447
- Parameters
448
- ----------
449
- in_size : Size
450
- Size of the input.
451
- name : str, optional
452
- Name of the synapse instance.
453
- alpha : ArrayLike, default=0.53/(u.ms*u.mM)
454
- Binding rate constant [ms^-1 mM^-1]. Typically slower than AMPA receptors.
455
- beta : ArrayLike, default=0.18/u.ms
456
- Unbinding rate constant [ms^-1].
457
- T : ArrayLike, default=1.0*u.mM
458
- Peak neurotransmitter concentration when released [mM]. Higher than AMPA.
459
- T_dur : ArrayLike, default=1.0*u.ms
460
- Duration of neurotransmitter presence in the synaptic cleft [ms]. Longer than AMPA.
461
- g_initializer : ArrayLike or Callable, default=init.ZeroInit()
462
- Initial value or initializer for the synaptic conductance.
463
-
464
- Attributes
465
- ----------
466
- Inherits all attributes from AMPA class.
467
-
468
- Notes
469
- -----
470
- - GABAa receptors typically produce slower-rising and longer-lasting currents compared to AMPA receptors.
471
- - The inhibitory nature of GABAa receptors is reflected in the convention of using a negative sign in the
472
- synaptic current equation.
473
- - The reversal potential for GABAa receptors is typically around -80 mV (due to chloride), making them
474
- inhibitory for neurons with resting potentials more positive than this value.
475
- - This model does not include desensitization, which can be significant for prolonged GABA exposure.
476
-
477
- References
478
- ----------
479
- .. [1] Destexhe, A., Mainen, Z. F., & Sejnowski, T. J. (1994). Synthesis of models for
480
- excitable membranes, synaptic transmission and neuromodulation using a common
481
- kinetic formalism. Journal of computational neuroscience, 1(3), 195-230.
482
- .. [2] Destexhe, A., & Paré, D. (1999). Impact of network activity on the integrative
483
- properties of neocortical pyramidal neurons in vivo. Journal of neurophysiology,
484
- 81(4), 1531-1547.
485
- """
486
-
487
- def __init__(
488
- self,
489
- in_size: Size,
490
- name: Optional[str] = None,
491
- alpha: ArrayLike = 0.53 / (u.ms * u.mM),
492
- beta: ArrayLike = 0.18 / u.ms,
493
- T: ArrayLike = 1.0 * u.mM,
494
- T_dur: ArrayLike = 1.0 * u.ms,
495
- g_initializer: ArrayLike | Callable = init.ZeroInit(unit=u.mS),
496
- ):
497
- super().__init__(
498
- alpha=alpha,
499
- beta=beta,
500
- T=T,
501
- T_dur=T_dur,
502
- name=name,
503
- in_size=in_size,
504
- g_initializer=g_initializer
505
- )
@@ -1,131 +0,0 @@
1
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
-
17
- import unittest
18
-
19
- import brainunit as u
20
- import jax.numpy as jnp
21
- import pytest
22
-
23
- import brainstate
24
- from brainstate.nn import Expon, STP, STD
25
-
26
-
27
- class TestSynapse(unittest.TestCase):
28
- def setUp(self):
29
- self.in_size = 10
30
- self.batch_size = 5
31
- self.time_steps = 100
32
-
33
- def generate_input(self):
34
- return brainstate.random.randn(self.time_steps, self.batch_size, self.in_size) * u.mS
35
-
36
- def test_expon_synapse(self):
37
- tau = 20.0 * u.ms
38
- synapse = Expon(self.in_size, tau=tau)
39
- inputs = self.generate_input()
40
-
41
- # Test initialization
42
- self.assertEqual(synapse.in_size, (self.in_size,))
43
- self.assertEqual(synapse.out_size, (self.in_size,))
44
- self.assertEqual(synapse.tau, tau)
45
-
46
- # Test forward pass
47
- state = synapse.init_state(self.batch_size)
48
- call = brainstate.compile.jit(synapse)
49
- with brainstate.environ.context(dt=0.1 * u.ms):
50
- for t in range(self.time_steps):
51
- out = call(inputs[t])
52
- self.assertEqual(out.shape, (self.batch_size, self.in_size))
53
-
54
- # Test exponential decay
55
- constant_input = jnp.ones((self.batch_size, self.in_size)) * u.mS
56
- out1 = call(constant_input)
57
- out2 = call(constant_input)
58
- self.assertTrue(jnp.all(out2 > out1)) # Output should increase with constant input
59
-
60
- @pytest.mark.skip(reason="Not implemented yet")
61
- def test_stp_synapse(self):
62
- tau_d = 200.0 * u.ms
63
- tau_f = 20.0 * u.ms
64
- U = 0.2
65
- synapse = STP(self.in_size, tau_d=tau_d, tau_f=tau_f, U=U)
66
- inputs = self.generate_input()
67
-
68
- # Test initialization
69
- self.assertEqual(synapse.in_size, (self.in_size,))
70
- self.assertEqual(synapse.out_size, (self.in_size,))
71
- self.assertEqual(synapse.tau_d, tau_d)
72
- self.assertEqual(synapse.tau_f, tau_f)
73
- self.assertEqual(synapse.U, U)
74
-
75
- # Test forward pass
76
- state = synapse.init_state(self.batch_size)
77
- call = brainstate.compile.jit(synapse)
78
- for t in range(self.time_steps):
79
- out = call(inputs[t])
80
- self.assertEqual(out.shape, (self.batch_size, self.in_size))
81
-
82
- # Test short-term plasticity
83
- constant_input = jnp.ones((self.batch_size, self.in_size)) * u.mS
84
- out1 = call(constant_input)
85
- out2 = call(constant_input)
86
- self.assertTrue(jnp.any(out2 != out1)) # Output should change due to STP
87
-
88
- @pytest.mark.skip(reason="Not implemented yet")
89
- def test_std_synapse(self):
90
- tau = 200.0
91
- U = 0.2
92
- synapse = STD(self.in_size, tau=tau, U=U)
93
- inputs = self.generate_input()
94
-
95
- # Test initialization
96
- self.assertEqual(synapse.in_size, (self.in_size,))
97
- self.assertEqual(synapse.out_size, (self.in_size,))
98
- self.assertEqual(synapse.tau, tau)
99
- self.assertEqual(synapse.U, U)
100
-
101
- # Test forward pass
102
- state = synapse.init_state(self.batch_size)
103
- for t in range(self.time_steps):
104
- out = synapse(inputs[t])
105
- self.assertEqual(out.shape, (self.batch_size, self.in_size))
106
-
107
- # Test short-term depression
108
- constant_input = jnp.ones((self.batch_size, self.in_size))
109
- out1 = synapse(constant_input)
110
- out2 = synapse(constant_input)
111
- self.assertTrue(jnp.all(out2 < out1)) # Output should decrease due to STD
112
-
113
- def test_keep_size(self):
114
- in_size = (2, 3)
115
- for SynapseClass in [Expon, ]:
116
- synapse = SynapseClass(in_size)
117
- self.assertEqual(synapse.in_size, in_size)
118
- self.assertEqual(synapse.out_size, in_size)
119
-
120
- inputs = brainstate.random.randn(self.time_steps, self.batch_size, *in_size) * u.mS
121
- state = synapse.init_state(self.batch_size)
122
- call = brainstate.compile.jit(synapse)
123
- with brainstate.environ.context(dt=0.1 * u.ms):
124
- for t in range(self.time_steps):
125
- out = call(inputs[t])
126
- self.assertEqual(out.shape, (self.batch_size, *in_size))
127
-
128
-
129
- if __name__ == '__main__':
130
- with brainstate.environ.context(dt=0.1):
131
- unittest.main()