brainstate 0.1.2__py2.py3-none-any.whl → 0.1.4__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66)
  1. brainstate/__init__.py +1 -1
  2. brainstate/_compatible_import.py +0 -15
  3. brainstate/compile/_jit.py +14 -5
  4. brainstate/compile/_make_jaxpr.py +78 -22
  5. brainstate/compile/_make_jaxpr_test.py +13 -2
  6. brainstate/graph/_graph_node.py +1 -1
  7. brainstate/graph/_graph_operation.py +4 -4
  8. brainstate/mixin.py +30 -14
  9. brainstate/nn/__init__.py +84 -17
  10. brainstate/nn/{_interaction/_conv.py → _conv.py} +1 -1
  11. brainstate/nn/{_dynamics/_state_delay.py → _delay.py} +19 -3
  12. brainstate/nn/{_elementwise/_dropout.py → _dropout.py} +6 -5
  13. brainstate/nn/{_dynamics/_dynamics_base.py → _dynamics.py} +137 -21
  14. brainstate/nn/{_elementwise/_elementwise.py → _elementwise.py} +1 -1
  15. brainstate/nn/{_interaction/_embedding.py → _embedding.py} +1 -1
  16. brainstate/nn/{_event/_fixedprob_mv.py → _fixedprob.py} +96 -25
  17. brainstate/nn/{_dyn_impl/_inputs.py → _inputs.py} +4 -5
  18. brainstate/nn/{_interaction/_linear.py → _linear.py} +2 -5
  19. brainstate/nn/{_event/_linear_mv.py → _linear_mv.py} +2 -2
  20. brainstate/nn/{_event/__init__.py → _ltp.py} +7 -5
  21. brainstate/nn/_module.py +5 -5
  22. brainstate/nn/{_dyn_impl/_dynamics_neuron.py → _neuron.py} +2 -2
  23. brainstate/nn/{_interaction/_normalizations.py → _normalizations.py} +1 -1
  24. brainstate/nn/{_interaction/_poolings.py → _poolings.py} +1 -1
  25. brainstate/nn/{_interaction/_poolings_test.py → _poolings_test.py} +1 -1
  26. brainstate/nn/_projection.py +486 -0
  27. brainstate/nn/{_dyn_impl/_rate_rnns.py → _rate_rnns.py} +2 -2
  28. brainstate/nn/{_dyn_impl/_readout.py → _readout.py} +3 -3
  29. brainstate/nn/_stp.py +236 -0
  30. brainstate/nn/{_dyn_impl/_dynamics_synapse.py → _synapse.py} +19 -212
  31. brainstate/nn/_synaptic_projection.py +423 -0
  32. brainstate/nn/{_dynamics/_synouts.py → _synouts.py} +4 -1
  33. brainstate/surrogate.py +1 -1
  34. brainstate/typing.py +1 -1
  35. brainstate/util/__init__.py +14 -14
  36. brainstate/util/{_pretty_pytree.py → pretty_pytree.py} +2 -2
  37. {brainstate-0.1.2.dist-info → brainstate-0.1.4.dist-info}/METADATA +1 -1
  38. {brainstate-0.1.2.dist-info → brainstate-0.1.4.dist-info}/RECORD +61 -63
  39. brainstate/nn/_dyn_impl/__init__.py +0 -42
  40. brainstate/nn/_dynamics/__init__.py +0 -37
  41. brainstate/nn/_dynamics/_projection_base.py +0 -362
  42. brainstate/nn/_elementwise/__init__.py +0 -22
  43. brainstate/nn/_interaction/__init__.py +0 -41
  44. /brainstate/nn/{_interaction/_conv_test.py → _conv_test.py} +0 -0
  45. /brainstate/nn/{_elementwise/_dropout_test.py → _dropout_test.py} +0 -0
  46. /brainstate/nn/{_dynamics/_dynamics_base_test.py → _dynamics_test.py} +0 -0
  47. /brainstate/nn/{_elementwise/_elementwise_test.py → _elementwise_test.py} +0 -0
  48. /brainstate/nn/{_event/_fixedprob_mv_test.py → _fixedprob_test.py} +0 -0
  49. /brainstate/nn/{_event/_linear_mv_test.py → _linear_mv_test.py} +0 -0
  50. /brainstate/nn/{_interaction/_linear_test.py → _linear_test.py} +0 -0
  51. /brainstate/nn/{_dyn_impl/_dynamics_neuron_test.py → _neuron_test.py} +0 -0
  52. /brainstate/nn/{_interaction/_normalizations_test.py → _normalizations_test.py} +0 -0
  53. /brainstate/nn/{_dyn_impl/_rate_rnns_test.py → _rate_rnns_test.py} +0 -0
  54. /brainstate/nn/{_dyn_impl/_readout_test.py → _readout_test.py} +0 -0
  55. /brainstate/nn/{_dyn_impl/_dynamics_synapse_test.py → _synapse_test.py} +0 -0
  56. /brainstate/nn/{_dynamics/_synouts_test.py → _synouts_test.py} +0 -0
  57. /brainstate/util/{_caller.py → caller.py} +0 -0
  58. /brainstate/util/{_error.py → error.py} +0 -0
  59. /brainstate/util/{_others.py → others.py} +0 -0
  60. /brainstate/util/{_pretty_repr.py → pretty_repr.py} +0 -0
  61. /brainstate/util/{_pretty_table.py → pretty_table.py} +0 -0
  62. /brainstate/util/{_scaling.py → scaling.py} +0 -0
  63. /brainstate/util/{_struct.py → struct.py} +0 -0
  64. {brainstate-0.1.2.dist-info → brainstate-0.1.4.dist-info}/LICENSE +0 -0
  65. {brainstate-0.1.2.dist-info → brainstate-0.1.4.dist-info}/WHEEL +0 -0
  66. {brainstate-0.1.2.dist-info → brainstate-0.1.4.dist-info}/top_level.txt +0 -0
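
The renames above flatten brainstate's private submodule layout: the nn subpackages (_dyn_impl, _dynamics, _elementwise, _event, _interaction) are folded into flat modules under brainstate/nn, and brainstate/util drops the leading underscores from its module names. A minimal sketch of what that means for imports, assuming the public brainstate.nn and brainstate.util namespaces continue to re-export these symbols (not verified in this diff):

    # brainstate 0.1.2 (private module paths, now gone)
    #   brainstate/nn/_dyn_impl/_dynamics_synapse.py
    #   brainstate/util/_pretty_pytree.py
    #
    # brainstate 0.1.4 (flattened layout)
    #   brainstate/nn/_synapse.py, brainstate/nn/_stp.py
    #   brainstate/util/pretty_pytree.py

    # Direct imports of the old private modules will break after upgrading;
    # the public namespaces are presumably still the supported entry points:
    import brainstate.nn
    import brainstate.util
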
brainstate/nn/_stp.py ADDED
@@ -0,0 +1,236 @@
+ # Copyright 2025 BDP Ecosystem Limited. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ # -*- coding: utf-8 -*-
+
+ from typing import Optional
+
+ import brainunit as u
+
+ from brainstate import init
+ from brainstate._state import HiddenState
+ from brainstate.typing import ArrayLike, Size
+ from ._exp_euler import exp_euler_step
+ from ._synapse import Synapse
+
+ __all__ = [
+     'ShortTermPlasticity', 'STP', 'STD',
+ ]
+
+
+ class ShortTermPlasticity(Synapse):
+     pass
+
+
+ class STP(ShortTermPlasticity):
+     r"""
+     Synapse with short-term plasticity.
+
+     This class implements a synapse model with short-term plasticity (STP), which captures
+     activity-dependent changes in synaptic efficacy that occur over milliseconds to seconds.
+     The model simultaneously accounts for both short-term facilitation and depression
+     based on the formulation by Tsodyks & Markram (1998).
+
+     The model is characterized by the following equations:
+
+     $$
+     \frac{du}{dt} = -\frac{u}{\tau_f} + U \cdot (1 - u) \cdot \delta(t - t_{spike})
+     $$
+
+     $$
+     \frac{dx}{dt} = \frac{1 - x}{\tau_d} - u \cdot x \cdot \delta(t - t_{spike})
+     $$
+
+     $$
+     g_{syn} = u \cdot x
+     $$
+
+     where:
+     - $u$ represents the utilization of synaptic efficacy (facilitation variable)
+     - $x$ represents the available synaptic resources (depression variable)
+     - $\tau_f$ is the facilitation time constant
+     - $\tau_d$ is the depression time constant
+     - $U$ is the baseline utilization parameter
+     - $\delta(t - t_{spike})$ is the Dirac delta function representing presynaptic spikes
+     - $g_{syn}$ is the effective synaptic conductance
+
+     Parameters
+     ----------
+     in_size : Size
+         Size of the input.
+     name : str, optional
+         Name of the synapse instance.
+     U : ArrayLike, default=0.15
+         Baseline utilization parameter (fraction of resources used per action potential).
+     tau_f : ArrayLike, default=1500.*u.ms
+         Time constant of short-term facilitation in milliseconds.
+     tau_d : ArrayLike, default=200.*u.ms
+         Time constant of short-term depression (recovery of synaptic resources) in milliseconds.
+
+     Attributes
+     ----------
+     u : HiddenState
+         Utilization of synaptic efficacy (facilitation variable).
+     x : HiddenState
+         Available synaptic resources (depression variable).
+
+     Notes
+     -----
+     - Larger values of tau_f produce stronger facilitation effects.
+     - Larger values of tau_d lead to slower recovery from depression.
+     - The parameter U controls the initial release probability.
+     - The effective synaptic strength is the product of u and x.
+
+     References
+     ----------
+     .. [1] Tsodyks, M. V., & Markram, H. (1997). The neural code between neocortical
+            pyramidal neurons depends on neurotransmitter release probability.
+            Proceedings of the National Academy of Sciences, 94(2), 719-723.
+     .. [2] Tsodyks, M., Pawelzik, K., & Markram, H. (1998). Neural networks with dynamic
+            synapses. Neural computation, 10(4), 821-835.
+     """
+     __module__ = 'brainstate.nn'
+
+     def __init__(
+         self,
+         in_size: Size,
+         name: Optional[str] = None,
+         U: ArrayLike = 0.15,
+         tau_f: ArrayLike = 1500. * u.ms,
+         tau_d: ArrayLike = 200. * u.ms,
+     ):
+         super().__init__(name=name, in_size=in_size)
+
+         # parameters
+         self.tau_f = init.param(tau_f, self.varshape)
+         self.tau_d = init.param(tau_d, self.varshape)
+         self.U = init.param(U, self.varshape)
+
+     def init_state(self, batch_size: int = None, **kwargs):
+         self.x = HiddenState(init.param(init.Constant(1.), self.varshape, batch_size))
+         self.u = HiddenState(init.param(init.Constant(self.U), self.varshape, batch_size))
+
+     def reset_state(self, batch_size: int = None, **kwargs):
+         self.x.value = init.param(init.Constant(1.), self.varshape, batch_size)
+         self.u.value = init.param(init.Constant(self.U), self.varshape, batch_size)
+
+     def update(self, pre_spike):
+         u = exp_euler_step(lambda u: - u / self.tau_f, self.u.value)
+         x = exp_euler_step(lambda x: (1 - x) / self.tau_d, self.x.value)
+
+         # --- original code:
+         # if pre_spike.dtype == jax.numpy.bool_:
+         #     u = bm.where(pre_spike, u + self.U * (1 - self.u), u)
+         #     x = bm.where(pre_spike, x - u * self.x, x)
+         # else:
+         #     u = pre_spike * (u + self.U * (1 - self.u)) + (1 - pre_spike) * u
+         #     x = pre_spike * (x - u * self.x) + (1 - pre_spike) * x
+
+         # --- simplified code:
+         u = u + pre_spike * self.U * (1 - self.u.value)
+         x = x - pre_spike * u * self.x.value
+
+         self.u.value = u
+         self.x.value = x
+         return u * x * pre_spike
+
+
+ class STD(ShortTermPlasticity):
+     r"""
+     Synapse with short-term depression.
+
+     This class implements a synapse model with short-term depression (STD), which captures
+     activity-dependent reduction in synaptic efficacy, typically caused by depletion of
+     neurotransmitter vesicles following repeated stimulation.
+
+     The model is characterized by the following equation:
+
+     $$
+     \frac{dx}{dt} = \frac{1 - x}{\tau} - U \cdot x \cdot \delta(t - t_{spike})
+     $$
+
+     $$
+     g_{syn} = x
+     $$
+
+     where:
+     - $x$ represents the available synaptic resources (depression variable)
+     - $\tau$ is the depression recovery time constant
+     - $U$ is the utilization parameter (fraction of resources depleted per spike)
+     - $\delta(t - t_{spike})$ is the Dirac delta function representing presynaptic spikes
+     - $g_{syn}$ is the effective synaptic conductance
+
+     Parameters
+     ----------
+     in_size : Size
+         Size of the input.
+     name : str, optional
+         Name of the synapse instance.
+     tau : ArrayLike, default=200.*u.ms
+         Time constant governing recovery of synaptic resources in milliseconds.
+     U : ArrayLike, default=0.07
+         Utilization parameter (fraction of resources used per action potential).
+
+     Attributes
+     ----------
+     x : HiddenState
+         Available synaptic resources (depression variable).
+
+     Notes
+     -----
+     - Larger values of tau lead to slower recovery from depression.
+     - Larger values of U cause stronger depression with each spike.
+     - This model is a simplified version of the STP model that only includes depression.
+
+     References
+     ----------
+     .. [1] Abbott, L. F., Varela, J. A., Sen, K., & Nelson, S. B. (1997). Synaptic
+            depression and cortical gain control. Science, 275(5297), 220-224.
+     .. [2] Tsodyks, M. V., & Markram, H. (1997). The neural code between neocortical
+            pyramidal neurons depends on neurotransmitter release probability.
+            Proceedings of the National Academy of Sciences, 94(2), 719-723.
+     """
+     __module__ = 'brainstate.nn'
+
+     def __init__(
+         self,
+         in_size: Size,
+         name: Optional[str] = None,
+         # synapse parameters
+         tau: ArrayLike = 200. * u.ms,
+         U: ArrayLike = 0.07,
+     ):
+         super().__init__(name=name, in_size=in_size)
+
+         # parameters
+         self.tau = init.param(tau, self.varshape)
+         self.U = init.param(U, self.varshape)
+
+     def init_state(self, batch_size: int = None, **kwargs):
+         self.x = HiddenState(init.param(init.Constant(1.), self.varshape, batch_size))
+
+     def reset_state(self, batch_size: int = None, **kwargs):
+         self.x.value = init.param(init.Constant(1.), self.varshape, batch_size)
+
+     def update(self, pre_spike):
+         x = exp_euler_step(lambda x: (1 - x) / self.tau, self.x.value)
+
+         # --- original code:
+         # self.x.value = bm.where(pre_spike, x - self.U * self.x, x)
+
+         # --- simplified code:
+         self.x.value = x - pre_spike * self.U * self.x.value
+
+         return self.x.value * pre_spike
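
For orientation, a minimal usage sketch of the classes added above, assuming brainstate.nn re-exports STP and STD (as __module__ = 'brainstate.nn' suggests) and that the integration step consumed by exp_euler_step is configured through brainstate.environ; the exact environ call is an assumption, it is not shown in this diff:

    import numpy as np
    import brainunit as u
    import brainstate

    brainstate.environ.set(dt=0.1 * u.ms)        # assumed way to set the step size
    stp = brainstate.nn.STP(in_size=10, U=0.15,
                            tau_f=1500. * u.ms, tau_d=200. * u.ms)
    stp.init_state()                             # allocates the u and x HiddenStates

    pre_spike = np.zeros(10)
    pre_spike[3] = 1.0                           # one presynaptic spike this step
    g = stp.update(pre_spike)                    # returns u * x * pre_spike
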
brainstate/nn/{_dyn_impl/_dynamics_synapse.py → _synapse.py}
@@ -23,12 +23,12 @@ import brainunit as u
  from brainstate import init, environ
  from brainstate._state import ShortTermState, HiddenState
  from brainstate.mixin import AlignPost
- from brainstate.nn._dynamics._dynamics_base import Dynamics
- from brainstate.nn._exp_euler import exp_euler_step
- from brainstate.typing import ArrayLike, Size
+ from brainstate.typing import ArrayLike, Size, PyTree
+ from ._dynamics import Dynamics
+ from ._exp_euler import exp_euler_step

  __all__ = [
-     'Synapse', 'Expon', 'DualExpon', 'Alpha', 'STP', 'STD', 'AMPA', 'GABAa',
+     'Synapse', 'Expon', 'DualExpon', 'Alpha', 'AMPA', 'GABAa',
  ]


@@ -123,6 +123,9 @@ class Expon(Synapse, AlignPost):
          g = exp_euler_step(lambda g: self.sum_current_inputs(-g) / self.tau, self.g.value)
          self.g.value = self.sum_delta_inputs(g)
          if x is not None: self.g.value += x
+         return self.update_return()
+
+     def update_return(self) -> PyTree:
          return self.g.value


@@ -229,6 +232,9 @@ class DualExpon(Synapse, AlignPost):
          if x is not None:
              self.g_rise.value += x
              self.g_decay.value += x
+         return self.update_return()
+
+     def update_return(self) -> PyTree:
          return self.a * (self.g_decay.value - self.g_rise.value)


@@ -304,208 +310,6 @@ class Alpha(Synapse):
          return self.g.value


- class STP(Synapse):
-     r"""
-     Synapse with short-term plasticity.
-
-     This class implements a synapse model with short-term plasticity (STP), which captures
-     activity-dependent changes in synaptic efficacy that occur over milliseconds to seconds.
-     The model simultaneously accounts for both short-term facilitation and depression
-     based on the formulation by Tsodyks & Markram (1998).
-
-     The model is characterized by the following equations:
-
-     $$
-     \frac{du}{dt} = -\frac{u}{\tau_f} + U \cdot (1 - u) \cdot \delta(t - t_{spike})
-     $$
-
-     $$
-     \frac{dx}{dt} = \frac{1 - x}{\tau_d} - u \cdot x \cdot \delta(t - t_{spike})
-     $$
-
-     $$
-     g_{syn} = u \cdot x
-     $$
-
-     where:
-     - $u$ represents the utilization of synaptic efficacy (facilitation variable)
-     - $x$ represents the available synaptic resources (depression variable)
-     - $\tau_f$ is the facilitation time constant
-     - $\tau_d$ is the depression time constant
-     - $U$ is the baseline utilization parameter
-     - $\delta(t - t_{spike})$ is the Dirac delta function representing presynaptic spikes
-     - $g_{syn}$ is the effective synaptic conductance
-
-     Parameters
-     ----------
-     in_size : Size
-         Size of the input.
-     name : str, optional
-         Name of the synapse instance.
-     U : ArrayLike, default=0.15
-         Baseline utilization parameter (fraction of resources used per action potential).
-     tau_f : ArrayLike, default=1500.*u.ms
-         Time constant of short-term facilitation in milliseconds.
-     tau_d : ArrayLike, default=200.*u.ms
-         Time constant of short-term depression (recovery of synaptic resources) in milliseconds.
-
-     Attributes
-     ----------
-     u : HiddenState
-         Utilization of synaptic efficacy (facilitation variable).
-     x : HiddenState
-         Available synaptic resources (depression variable).
-
-     Notes
-     -----
-     - Larger values of tau_f produce stronger facilitation effects.
-     - Larger values of tau_d lead to slower recovery from depression.
-     - The parameter U controls the initial release probability.
-     - The effective synaptic strength is the product of u and x.
-
-     References
-     ----------
-     .. [1] Tsodyks, M. V., & Markram, H. (1997). The neural code between neocortical
-            pyramidal neurons depends on neurotransmitter release probability.
-            Proceedings of the National Academy of Sciences, 94(2), 719-723.
-     .. [2] Tsodyks, M., Pawelzik, K., & Markram, H. (1998). Neural networks with dynamic
-            synapses. Neural computation, 10(4), 821-835.
-     """
-     __module__ = 'brainstate.nn'
-
-     def __init__(
-         self,
-         in_size: Size,
-         name: Optional[str] = None,
-         U: ArrayLike = 0.15,
-         tau_f: ArrayLike = 1500. * u.ms,
-         tau_d: ArrayLike = 200. * u.ms,
-     ):
-         super().__init__(name=name, in_size=in_size)
-
-         # parameters
-         self.tau_f = init.param(tau_f, self.varshape)
-         self.tau_d = init.param(tau_d, self.varshape)
-         self.U = init.param(U, self.varshape)
-
-     def init_state(self, batch_size: int = None, **kwargs):
-         self.x = HiddenState(init.param(init.Constant(1.), self.varshape, batch_size))
-         self.u = HiddenState(init.param(init.Constant(self.U), self.varshape, batch_size))
-
-     def reset_state(self, batch_size: int = None, **kwargs):
-         self.x.value = init.param(init.Constant(1.), self.varshape, batch_size)
-         self.u.value = init.param(init.Constant(self.U), self.varshape, batch_size)
-
-     def update(self, pre_spike):
-         u = exp_euler_step(lambda u: - u / self.tau_f, self.u.value)
-         x = exp_euler_step(lambda x: (1 - x) / self.tau_d, self.x.value)
-
-         # --- original code:
-         # if pre_spike.dtype == jax.numpy.bool_:
-         #     u = bm.where(pre_spike, u + self.U * (1 - self.u), u)
-         #     x = bm.where(pre_spike, x - u * self.x, x)
-         # else:
-         #     u = pre_spike * (u + self.U * (1 - self.u)) + (1 - pre_spike) * u
-         #     x = pre_spike * (x - u * self.x) + (1 - pre_spike) * x
-
-         # --- simplified code:
-         u = u + pre_spike * self.U * (1 - self.u.value)
-         x = x - pre_spike * u * self.x.value
-
-         self.u.value = u
-         self.x.value = x
-         return u * x * pre_spike
-
-
- class STD(Synapse):
-     r"""
-     Synapse with short-term depression.
-
-     This class implements a synapse model with short-term depression (STD), which captures
-     activity-dependent reduction in synaptic efficacy, typically caused by depletion of
-     neurotransmitter vesicles following repeated stimulation.
-
-     The model is characterized by the following equation:
-
-     $$
-     \frac{dx}{dt} = \frac{1 - x}{\tau} - U \cdot x \cdot \delta(t - t_{spike})
-     $$
-
-     $$
-     g_{syn} = x
-     $$
-
-     where:
-     - $x$ represents the available synaptic resources (depression variable)
-     - $\tau$ is the depression recovery time constant
-     - $U$ is the utilization parameter (fraction of resources depleted per spike)
-     - $\delta(t - t_{spike})$ is the Dirac delta function representing presynaptic spikes
-     - $g_{syn}$ is the effective synaptic conductance
-
-     Parameters
-     ----------
-     in_size : Size
-         Size of the input.
-     name : str, optional
-         Name of the synapse instance.
-     tau : ArrayLike, default=200.*u.ms
-         Time constant governing recovery of synaptic resources in milliseconds.
-     U : ArrayLike, default=0.07
-         Utilization parameter (fraction of resources used per action potential).
-
-     Attributes
-     ----------
-     x : HiddenState
-         Available synaptic resources (depression variable).
-
-     Notes
-     -----
-     - Larger values of tau lead to slower recovery from depression.
-     - Larger values of U cause stronger depression with each spike.
-     - This model is a simplified version of the STP model that only includes depression.
-
-     References
-     ----------
-     .. [1] Abbott, L. F., Varela, J. A., Sen, K., & Nelson, S. B. (1997). Synaptic
-            depression and cortical gain control. Science, 275(5297), 220-224.
-     .. [2] Tsodyks, M. V., & Markram, H. (1997). The neural code between neocortical
-            pyramidal neurons depends on neurotransmitter release probability.
-            Proceedings of the National Academy of Sciences, 94(2), 719-723.
-     """
-     __module__ = 'brainstate.nn'
-
-     def __init__(
-         self,
-         in_size: Size,
-         name: Optional[str] = None,
-         # synapse parameters
-         tau: ArrayLike = 200. * u.ms,
-         U: ArrayLike = 0.07,
-     ):
-         super().__init__(name=name, in_size=in_size)
-
-         # parameters
-         self.tau = init.param(tau, self.varshape)
-         self.U = init.param(U, self.varshape)
-
-     def init_state(self, batch_size: int = None, **kwargs):
-         self.x = HiddenState(init.param(init.Constant(1.), self.varshape, batch_size))
-
-     def reset_state(self, batch_size: int = None, **kwargs):
-         self.x.value = init.param(init.Constant(1.), self.varshape, batch_size)
-
-     def update(self, pre_spike):
-         x = exp_euler_step(lambda x: (1 - x) / self.tau, self.x.value)
-
-         # --- original code:
-         # self.x.value = bm.where(pre_spike, x - self.U * self.x, x)
-
-         # --- simplified code:
-         self.x.value = x - pre_spike * self.U * self.x.value
-
-         return self.x.value * pre_spike
-
-
  class AMPA(Synapse):
      r"""AMPA receptor synapse model.

@@ -587,7 +391,7 @@ class AMPA(Synapse):
          beta: ArrayLike = 0.18 / u.ms,
          T: ArrayLike = 0.5 * u.mM,
          T_dur: ArrayLike = 0.5 * u.ms,
-         g_initializer: ArrayLike | Callable = init.ZeroInit(),
+         g_initializer: ArrayLike | Callable = init.ZeroInit(unit=u.mS),
      ):
          super().__init__(name=name, in_size=in_size)

@@ -606,14 +410,16 @@ class AMPA(Synapse):
          self.g.value = init.param(self.g_initializer, self.varshape, batch_or_mode)
          self.spike_arrival_time.value = init.param(init.Constant(-1e7 * u.ms), self.varshape, batch_or_mode)

-     def dg(self, g, t, TT):
-         return self.alpha * TT * (1 - g) - self.beta * g
-
      def update(self, pre_spike):
          t = environ.get('t')
          self.spike_arrival_time.value = u.math.where(pre_spike, t, self.spike_arrival_time.value)
          TT = ((t - self.spike_arrival_time.value) < self.T_duration) * self.T
-         self.g.value = exp_euler_step(self.dg, self.g.value, t, TT)
+         dg = lambda g: self.alpha * TT * (1 - g) - self.beta * g
+         self.g.value = exp_euler_step(dg, self.g.value)
+         return self.update_return()
+
+     def update_return(self) -> PyTree:
+         """Return the synaptic conductance value."""
          return self.g.value


@@ -696,7 +502,7 @@ class GABAa(AMPA):
          beta: ArrayLike = 0.18 / u.ms,
          T: ArrayLike = 1.0 * u.mM,
          T_dur: ArrayLike = 1.0 * u.ms,
-         g_initializer: ArrayLike | Callable = init.ZeroInit(),
+         g_initializer: ArrayLike | Callable = init.ZeroInit(unit=u.mS),
      ):
          super().__init__(
              alpha=alpha,
@@ -707,3 +513,4 @@ class GABAa(AMPA):
              in_size=in_size,
              g_initializer=g_initializer
          )
+
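
A closing note on the change above: in 0.1.4, update() both advances the synaptic state and delegates its return value to a new update_return() hook, so the current output can be read without stepping the dynamics again. A minimal sketch of a custom synapse following that split; the class and parameter names are hypothetical, and the brainstate.nn.Synapse and brainstate.nn.exp_euler_step re-exports are assumptions based on the flattened layout shown in this diff:

    import brainunit as u
    import brainstate
    from brainstate import init
    from brainstate._state import HiddenState

    class LeakyTrace(brainstate.nn.Synapse):     # illustrative example class
        def __init__(self, in_size, tau=10. * u.ms, name=None):
            super().__init__(name=name, in_size=in_size)
            self.tau = init.param(tau, self.varshape)

        def init_state(self, batch_size=None, **kwargs):
            self.g = HiddenState(init.param(init.Constant(0.), self.varshape, batch_size))

        def update(self, x=None):
            # step the state, mirroring Expon/DualExpon/AMPA in 0.1.4 ...
            g = brainstate.nn.exp_euler_step(lambda g: -g / self.tau, self.g.value)
            self.g.value = g if x is None else g + x
            return self.update_return()

        def update_return(self):
            # ... and expose the current output through the separate read-only hook
            return self.g.value
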