brainstate 0.1.3__py2.py3-none-any.whl → 0.1.5__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +1 -1
- brainstate/_compatible_import.py +1 -16
- brainstate/_state.py +1 -0
- brainstate/augment/_mapping.py +9 -9
- brainstate/augment/_mapping_test.py +162 -0
- brainstate/compile/_jit.py +14 -5
- brainstate/compile/_make_jaxpr.py +78 -22
- brainstate/compile/_make_jaxpr_test.py +13 -2
- brainstate/graph/_graph_node.py +1 -1
- brainstate/graph/_graph_operation.py +4 -4
- brainstate/mixin.py +31 -2
- brainstate/nn/__init__.py +8 -5
- brainstate/nn/_common.py +7 -19
- brainstate/nn/_delay.py +13 -1
- brainstate/nn/_dropout.py +5 -4
- brainstate/nn/_dynamics.py +39 -44
- brainstate/nn/_exp_euler.py +13 -16
- brainstate/nn/{_fixedprob_mv.py → _fixedprob.py} +95 -24
- brainstate/nn/_inputs.py +1 -1
- brainstate/nn/_linear_mv.py +1 -1
- brainstate/nn/_module.py +5 -5
- brainstate/nn/_projection.py +190 -98
- brainstate/nn/_synapse.py +5 -9
- brainstate/nn/_synaptic_projection.py +376 -86
- brainstate/random/_rand_state.py +13 -7
- brainstate/surrogate.py +1 -1
- brainstate/typing.py +1 -1
- brainstate/util/__init__.py +14 -14
- brainstate/util/{_pretty_pytree.py → pretty_pytree.py} +2 -2
- {brainstate-0.1.3.dist-info → brainstate-0.1.5.dist-info}/METADATA +1 -1
- {brainstate-0.1.3.dist-info → brainstate-0.1.5.dist-info}/RECORD +42 -42
- /brainstate/nn/{_fixedprob_mv_test.py → _fixedprob_test.py} +0 -0
- /brainstate/util/{_caller.py → caller.py} +0 -0
- /brainstate/util/{_error.py → error.py} +0 -0
- /brainstate/util/{_others.py → others.py} +0 -0
- /brainstate/util/{_pretty_repr.py → pretty_repr.py} +0 -0
- /brainstate/util/{_pretty_table.py → pretty_table.py} +0 -0
- /brainstate/util/{_scaling.py → scaling.py} +0 -0
- /brainstate/util/{_struct.py → struct.py} +0 -0
- {brainstate-0.1.3.dist-info → brainstate-0.1.5.dist-info}/LICENSE +0 -0
- {brainstate-0.1.3.dist-info → brainstate-0.1.5.dist-info}/WHEEL +0 -0
- {brainstate-0.1.3.dist-info → brainstate-0.1.5.dist-info}/top_level.txt +0 -0
brainstate/nn/_linear_mv.py
CHANGED
@@ -19,7 +19,7 @@ import brainunit as u
|
|
19
19
|
import jax
|
20
20
|
|
21
21
|
from brainstate import init
|
22
|
-
|
22
|
+
import brainevent
|
23
23
|
from brainstate._state import ParamState
|
24
24
|
from brainstate.typing import Size, ArrayLike
|
25
25
|
from ._module import Module
|
brainstate/nn/_module.py
CHANGED
@@ -34,7 +34,7 @@ import numpy as np
|
|
34
34
|
from brainstate._state import State
|
35
35
|
from brainstate.graph import Node, states, nodes, flatten
|
36
36
|
from brainstate.mixin import ParamDescriber, ParamDesc
|
37
|
-
from brainstate.typing import PathParts
|
37
|
+
from brainstate.typing import PathParts, Size
|
38
38
|
from brainstate.util import FlattedDict, NestedDict, BrainStateError
|
39
39
|
|
40
40
|
# maximum integer
|
@@ -62,8 +62,8 @@ class Module(Node, ParamDesc):
|
|
62
62
|
|
63
63
|
__module__ = 'brainstate.nn'
|
64
64
|
|
65
|
-
_in_size: Optional[
|
66
|
-
_out_size: Optional[
|
65
|
+
_in_size: Optional[Size]
|
66
|
+
_out_size: Optional[Size]
|
67
67
|
_name: Optional[str]
|
68
68
|
|
69
69
|
if not TYPE_CHECKING:
|
@@ -87,7 +87,7 @@ class Module(Node, ParamDesc):
|
|
87
87
|
raise AttributeError('The name of the model is read-only.')
|
88
88
|
|
89
89
|
@property
|
90
|
-
def in_size(self) ->
|
90
|
+
def in_size(self) -> Size:
|
91
91
|
return self._in_size
|
92
92
|
|
93
93
|
@in_size.setter
|
@@ -98,7 +98,7 @@ class Module(Node, ParamDesc):
|
|
98
98
|
self._in_size = tuple(in_size)
|
99
99
|
|
100
100
|
@property
|
101
|
-
def out_size(self) ->
|
101
|
+
def out_size(self) -> Size:
|
102
102
|
return self._out_size
|
103
103
|
|
104
104
|
@out_size.setter
|
brainstate/nn/_projection.py
CHANGED
@@ -13,25 +13,31 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
15
|
|
16
|
-
from typing import
|
16
|
+
from typing import Callable, Union
|
17
|
+
from typing import Optional
|
18
|
+
|
19
|
+
import brainevent
|
20
|
+
import brainunit as u
|
17
21
|
|
18
22
|
from brainstate._state import State
|
19
|
-
from brainstate.mixin import
|
20
|
-
from brainstate.
|
23
|
+
from brainstate.mixin import BindCondData, JointTypes
|
24
|
+
from brainstate.mixin import ParamDescriber, AlignPost
|
25
|
+
from brainstate.util.others import get_unique_name
|
21
26
|
from ._collective_ops import call_order
|
22
|
-
from ._dynamics import Dynamics, maybe_init_prefetch, Prefetch, PrefetchDelayAt
|
27
|
+
from ._dynamics import Dynamics, Projection, maybe_init_prefetch, Prefetch, PrefetchDelayAt
|
23
28
|
from ._module import Module
|
29
|
+
from ._stp import ShortTermPlasticity
|
30
|
+
from ._synapse import Synapse
|
24
31
|
from ._synouts import SynOut
|
25
32
|
|
26
33
|
__all__ = [
|
27
34
|
'AlignPostProj',
|
28
35
|
'DeltaProj',
|
29
36
|
'CurrentProj',
|
30
|
-
]
|
31
|
-
|
32
37
|
|
33
|
-
|
34
|
-
|
38
|
+
'align_pre_projection',
|
39
|
+
'align_post_projection',
|
40
|
+
]
|
35
41
|
|
36
42
|
|
37
43
|
def _check_modules(*modules):
|
@@ -101,7 +107,7 @@ class _AlignPost(Module):
|
|
101
107
|
self.out.bind_cond(self.syn(*args, **kwargs))
|
102
108
|
|
103
109
|
|
104
|
-
class AlignPostProj(
|
110
|
+
class AlignPostProj(Projection):
|
105
111
|
"""
|
106
112
|
Align-post projection of the neural network.
|
107
113
|
|
@@ -113,7 +119,6 @@ class AlignPostProj(Interaction):
|
|
113
119
|
Note that this projection needs the manual input of pre-synaptic spikes.
|
114
120
|
|
115
121
|
>>> import brainstate
|
116
|
-
>>> import brainevent
|
117
122
|
>>> import brainunit as u
|
118
123
|
>>> n_exc = 3200
|
119
124
|
>>> n_inh = 800
|
@@ -124,16 +129,14 @@ class AlignPostProj(Interaction):
|
|
124
129
|
... tau=20. * u.ms, tau_ref=5. * u.ms,
|
125
130
|
... V_initializer=brainstate.init.Normal(-55., 2., unit=u.mV)
|
126
131
|
... )
|
127
|
-
>>> pop.
|
132
|
+
>>> pop.init_state()
|
128
133
|
>>> E = brainstate.nn.AlignPostProj(
|
129
|
-
... comm=
|
134
|
+
... comm=brainstate.nn.FixedNumConn(n_exc, num, prob=80 / num, weight=1.62 * u.mS),
|
130
135
|
... syn=brainstate.nn.Expon.desc(num, tau=5. * u.ms),
|
131
136
|
... out=brainstate.nn.CUBA.desc(scale=u.volt),
|
132
137
|
... post=pop
|
133
138
|
... )
|
134
|
-
>>> exe_current = E(pop.
|
135
|
-
|
136
|
-
|
139
|
+
>>> exe_current = E(pop.get_spike())
|
137
140
|
|
138
141
|
"""
|
139
142
|
__module__ = 'brainstate.nn'
|
@@ -226,48 +229,47 @@ class AlignPostProj(Interaction):
|
|
226
229
|
self.out.bind_cond(conductance)
|
227
230
|
|
228
231
|
|
229
|
-
class DeltaProj(
|
230
|
-
r"""Full-chain of the synaptic projection for the Delta synapse model.
|
231
|
-
|
232
|
-
The synaptic projection requires the input is the spiking data, otherwise
|
233
|
-
the synapse is not the Delta synapse model.
|
234
|
-
|
235
|
-
The ``full-chain`` means that the model needs to provide all information needed for a projection,
|
236
|
-
including ``pre`` -> ``delay`` -> ``comm`` -> ``post``.
|
237
|
-
|
238
|
-
**Model Descriptions**
|
239
|
-
|
240
|
-
.. math::
|
241
|
-
|
242
|
-
I_{syn} (t) = \sum_{j\in C} g_{\mathrm{max}} * \delta(t-t_j-D)
|
243
|
-
|
244
|
-
where :math:`g_{\mathrm{max}}` denotes the chemical synaptic strength,
|
245
|
-
:math:`t_j` the spiking moment of the presynaptic neuron :math:`j`,
|
246
|
-
:math:`C` the set of neurons connected to the post-synaptic neuron,
|
247
|
-
and :math:`D` the transmission delay of chemical synapses.
|
248
|
-
For simplicity, the rise and decay phases of post-synaptic currents are
|
249
|
-
omitted in this model.
|
250
|
-
|
251
|
-
# brainstate.nn.DeltaInteraction(
|
252
|
-
# LIF().prefetch('V'), bst.surrogate.ReluGrad(), comm, post
|
253
|
-
# )
|
254
|
-
|
255
|
-
Args:
|
256
|
-
pre: The pre-synaptic neuron group.
|
257
|
-
delay: The synaptic delay.
|
258
|
-
comm: DynamicalSystem. The synaptic communication.
|
259
|
-
post: DynamicalSystem. The post-synaptic neuron group.
|
232
|
+
class DeltaProj(Projection):
|
260
233
|
"""
|
234
|
+
Delta-based projection of the neural network.
|
235
|
+
|
236
|
+
This projection directly applies delta inputs to post-synaptic neurons without intervening
|
237
|
+
synaptic dynamics. It processes inputs through optional prefetch modules, applies a communication model,
|
238
|
+
and adds the result directly as a delta input to the post-synaptic population.
|
239
|
+
|
240
|
+
Parameters
|
241
|
+
----------
|
242
|
+
*prefetch : State or callable
|
243
|
+
Optional prefetch modules to process input before communication.
|
244
|
+
comm : callable
|
245
|
+
Communication model that determines how signals are transmitted.
|
246
|
+
post : Dynamics
|
247
|
+
Post-synaptic neural population to receive the delta inputs.
|
248
|
+
label : Optional[str], default=None
|
249
|
+
Optional label for the projection to identify it in the post-synaptic population.
|
261
250
|
|
251
|
+
Examples
|
252
|
+
--------
|
253
|
+
>>> import brainstate
|
254
|
+
>>> import brainunit as u
|
255
|
+
>>> n_neurons = 100
|
256
|
+
>>> pop = brainstate.nn.LIF(n_neurons, V_rest=-70*u.mV, V_threshold=-50*u.mV)
|
257
|
+
>>> pop.init_state()
|
258
|
+
>>> delta_input = brainstate.nn.DeltaProj(
|
259
|
+
... comm=lambda x: x * 10.0*u.mV,
|
260
|
+
... post=pop
|
261
|
+
... )
|
262
|
+
>>> delta_input(1.0) # Apply voltage increment directly
|
263
|
+
"""
|
262
264
|
__module__ = 'brainstate.nn'
|
263
265
|
|
264
|
-
def __init__(self, *
|
266
|
+
def __init__(self, *prefetch, comm: Callable, post: Dynamics, label=None):
|
265
267
|
super().__init__(name=get_unique_name(self.__class__.__name__))
|
266
268
|
|
267
269
|
self.label = label
|
268
270
|
|
269
271
|
# checking modules
|
270
|
-
self.
|
272
|
+
self.prefetches = _check_modules(*prefetch)
|
271
273
|
|
272
274
|
# checking communication model
|
273
275
|
if not callable(comm):
|
@@ -285,58 +287,69 @@ class DeltaProj(Interaction):
|
|
285
287
|
|
286
288
|
@call_order(2)
|
287
289
|
def init_state(self, *args, **kwargs):
|
288
|
-
for
|
289
|
-
maybe_init_prefetch(
|
290
|
+
for prefetch in self.prefetches:
|
291
|
+
maybe_init_prefetch(prefetch, *args, **kwargs)
|
290
292
|
|
291
293
|
def update(self, *x):
|
292
|
-
for module in self.
|
294
|
+
for module in self.prefetches:
|
293
295
|
x = (call_module(module, *x),)
|
294
296
|
assert len(x) == 1, f'The output of the modules should be a single value, but got {x}.'
|
295
297
|
x = self.comm(x[0])
|
296
298
|
self.post.add_delta_input(self.name, x, label=self.label)
|
297
299
|
|
298
300
|
|
299
|
-
class CurrentProj(
|
301
|
+
class CurrentProj(Projection):
|
300
302
|
"""
|
301
|
-
|
302
|
-
|
303
|
-
|
304
|
-
|
305
|
-
|
306
|
-
|
307
|
-
|
308
|
-
|
309
|
-
|
310
|
-
|
311
|
-
|
312
|
-
|
313
|
-
|
303
|
+
Current-based projection of the neural network.
|
304
|
+
|
305
|
+
This projection directly modulates post-synaptic currents without separate synaptic dynamics.
|
306
|
+
It processes inputs through optional prefetch modules, applies a communication model,
|
307
|
+
and binds the result to the output model which is then added as a current input to the post-synaptic population.
|
308
|
+
|
309
|
+
Parameters
|
310
|
+
----------
|
311
|
+
*prefetch : State or callable
|
312
|
+
Optional prefetch modules to process input before communication.
|
313
|
+
The last element must be an instance of Prefetch or PrefetchDelayAt if any are provided.
|
314
|
+
comm : callable
|
315
|
+
Communication model that determines how signals are transmitted.
|
316
|
+
out : SynOut
|
317
|
+
Output model that converts communication results to post-synaptic currents.
|
318
|
+
post : Dynamics
|
319
|
+
Post-synaptic neural population to receive the currents.
|
314
320
|
|
315
|
-
|
316
|
-
|
317
|
-
|
318
|
-
|
319
|
-
|
320
|
-
|
321
|
-
|
322
|
-
|
323
|
-
|
321
|
+
Examples
|
322
|
+
--------
|
323
|
+
>>> import brainstate
|
324
|
+
>>> import brainunit as u
|
325
|
+
>>> n_neurons = 100
|
326
|
+
>>> pop = brainstate.nn.LIF(n_neurons, V_rest=-70*u.mV, V_threshold=-50*u.mV)
|
327
|
+
>>> pop.init_state()
|
328
|
+
>>> current_input = brainstate.nn.CurrentProj(
|
329
|
+
... comm=lambda x: x * 0.5,
|
330
|
+
... out=brainstate.nn.CUBA(scale=1.0*u.nA),
|
331
|
+
... post=pop
|
332
|
+
... )
|
333
|
+
>>> current_input(0.2) # Apply external current
|
324
334
|
"""
|
325
335
|
__module__ = 'brainstate.nn'
|
326
336
|
|
327
337
|
def __init__(
|
328
338
|
self,
|
329
|
-
prefetch
|
339
|
+
*prefetch,
|
330
340
|
comm: Callable,
|
331
341
|
out: SynOut,
|
332
342
|
post: Dynamics,
|
333
343
|
):
|
334
344
|
super().__init__(name=get_unique_name(self.__class__.__name__))
|
335
345
|
|
336
|
-
#
|
337
|
-
if not isinstance(prefetch, (Prefetch, PrefetchDelayAt)):
|
338
|
-
raise TypeError(f'The pre should be a Prefetch or PrefetchDelayAt, but got {prefetch}.')
|
346
|
+
# check prefetch
|
339
347
|
self.prefetch = prefetch
|
348
|
+
if len(self.prefetch) > 0 and not isinstance(prefetch[-1], (Prefetch, PrefetchDelayAt)):
|
349
|
+
raise TypeError(
|
350
|
+
f'The last element of prefetch should be an instance of {Prefetch} or {PrefetchDelayAt}, '
|
351
|
+
f'but got {prefetch[-1]}.'
|
352
|
+
)
|
340
353
|
|
341
354
|
# check out
|
342
355
|
if not isinstance(out, SynOut):
|
@@ -354,41 +367,120 @@ class CurrentProj(Interaction):
|
|
354
367
|
|
355
368
|
@call_order(2)
|
356
369
|
def init_state(self, *args, **kwargs):
|
357
|
-
|
370
|
+
for prefetch in self.prefetch:
|
371
|
+
maybe_init_prefetch(prefetch, *args, **kwargs)
|
358
372
|
|
359
373
|
def update(self, *x):
|
360
|
-
|
361
|
-
|
374
|
+
for prefetch in self.prefetch:
|
375
|
+
x = (call_module(prefetch, *x),)
|
376
|
+
x = self.comm(*x)
|
362
377
|
self.out.bind_cond(x)
|
363
378
|
|
364
379
|
|
365
|
-
class
|
380
|
+
class align_pre_projection(Projection):
    """
    Pre-synaptic alignment projection of the neural network.

    On every update the inputs flow through the following steps:

    1. An optional chain of spike generators transforms the inputs.
    2. The result is wrapped in ``brainevent.BinaryArray``.
    3. Short-term plasticity is optionally applied.
    4. The pre-synaptic ``syn`` dynamics filter the spikes.
    5. The result is delivered to ``post`` through an internal
       :class:`CurrentProj` (``comm`` -> ``out`` -> ``post``).

    Parameters
    ----------
    *spike_generator : callable
        Modules applied in order to the update inputs. Each stage receives the
        previous stage's output unpacked as positional arguments and may return
        a single value or a tuple/list of values. After the last stage exactly
        one value must remain.
    syn : Dynamics
        Pre-synaptic dynamics applied to the (possibly STP-modulated) spikes.
    comm : callable
        Communication model that determines how signals are transmitted.
    out : SynOut
        Output model that converts communication results to post-synaptic currents.
    post : Dynamics
        Post-synaptic neural population receiving the currents.
    stp : ShortTermPlasticity, optional
        Short-term plasticity applied to the spikes before ``syn``.
        Defaults to None (no plasticity).

    Attributes
    ----------
    spike_generator
        The spike-generator modules, as returned by ``_check_modules``.
    syn : Dynamics
        The pre-synaptic dynamics object.
    projection : CurrentProj
        The current projection handling communication, output, and the
        post-synaptic population.
    stp : ShortTermPlasticity or None
        The short-term plasticity object, if any.
    """

    def __init__(
        self,
        *spike_generator,
        syn: Dynamics,
        comm: Callable,
        out: SynOut,
        post: Dynamics,
        stp: Optional[ShortTermPlasticity] = None,
    ):
        super().__init__()

        self.spike_generator = _check_modules(*spike_generator)
        # ``comm``, ``out`` and ``post`` are delegated to a CurrentProj.
        self.projection = CurrentProj(comm=comm, out=out, post=post)
        self.syn = syn
        self.stp = stp

    @call_order(2)
    def init_state(self, *args, **kwargs):
        # Initialize the spike generators via ``maybe_init_prefetch``.
        for module in self.spike_generator:
            maybe_init_prefetch(module, *args, **kwargs)

    def update(self, *x):
        # Chain the spike generators. Each stage's output is normalized to a
        # tuple immediately, so it can be unpacked as the positional arguments
        # of the next stage.
        for fun in self.spike_generator:
            x = fun(*x)
            x = tuple(x) if isinstance(x, (tuple, list)) else (x,)
        assert len(x) == 1, (
            f'The spike generators must produce exactly one value '
            f'(a single value or a one-element tuple/list), but got {len(x)}.'
        )
        # Wrap the spikes in a brainevent.BinaryArray.
        x = brainevent.BinaryArray(x[0])
        if self.stp is not None:
            # Wrap the STP output in a brainevent.MaskedFloat.
            x = brainevent.MaskedFloat(self.stp(x))
        x = self.syn(x)  # apply the pre-synaptic dynamics
        return self.projection(x)
|
434
|
+
|
435
|
+
|
436
|
+
class align_post_projection(Projection):
    """
    Post-synaptic alignment projection of the neural network.

    On every update the inputs flow through the following steps:

    1. An optional chain of spike generators transforms the inputs.
    2. The result is wrapped in ``brainevent.BinaryArray``.
    3. Short-term plasticity is optionally applied.
    4. The result is passed to an internal :class:`AlignPostProj`, built
       from ``comm``, ``syn``, ``out`` and ``post``.

    Parameters
    ----------
    *spike_generator : callable
        Modules applied in order to the update inputs. Each stage receives the
        previous stage's output unpacked as positional arguments and may return
        a single value or a tuple/list of values. After the last stage exactly
        one value must remain.
    comm : callable
        Communication function for the projection.
    syn : AlignPost or ParamDescriber[AlignPost]
        The post-synaptic alignment object or its parameter describer.
    out : SynOut or ParamDescriber[SynOut]
        The synaptic output object or its parameter describer.
    post : Dynamics
        The post-synaptic dynamics object.
    stp : ShortTermPlasticity, optional
        Short-term plasticity applied to the spikes before the projection.
        Defaults to None (no plasticity).

    Attributes
    ----------
    spike_generator
        The spike-generator modules, as returned by ``_check_modules``.
    projection : AlignPostProj
        The align-post projection handling ``comm``, ``syn``, ``out`` and ``post``.
    stp : ShortTermPlasticity or None
        The short-term plasticity object, if any.
    """

    def __init__(
        self,
        *spike_generator,
        comm: Callable,
        syn: Union[AlignPost, ParamDescriber[AlignPost]],
        out: Union[SynOut, ParamDescriber[SynOut]],
        post: Dynamics,
        stp: Optional[ShortTermPlasticity] = None,
    ):
        super().__init__()

        self.spike_generator = _check_modules(*spike_generator)
        # ``comm``, ``syn``, ``out`` and ``post`` are delegated to an AlignPostProj.
        self.projection = AlignPostProj(comm=comm, syn=syn, out=out, post=post)
        self.stp = stp

    @call_order(2)
    def init_state(self, *args, **kwargs):
        # Initialize the spike generators via ``maybe_init_prefetch``.
        for module in self.spike_generator:
            maybe_init_prefetch(module, *args, **kwargs)

    def update(self, *x):
        # Chain the spike generators. Each stage's output is normalized to a
        # tuple immediately, so it can be unpacked as the positional arguments
        # of the next stage.
        for fun in self.spike_generator:
            x = fun(*x)
            x = tuple(x) if isinstance(x, (tuple, list)) else (x,)
        assert len(x) == 1, (
            f'The spike generators must produce exactly one value '
            f'(a single value or a one-element tuple/list), but got {len(x)}.'
        )
        # Wrap the spikes in a brainevent.BinaryArray.
        x = brainevent.BinaryArray(x[0])
        if self.stp is not None:
            # Wrap the STP output in a brainevent.MaskedFloat.
            x = brainevent.MaskedFloat(self.stp(x))
        return self.projection(x)
|
brainstate/nn/_synapse.py
CHANGED
@@ -307,9 +307,6 @@ class Alpha(Synapse):
|
|
307
307
|
self.h.value = self.sum_delta_inputs(h)
|
308
308
|
if x is not None:
|
309
309
|
self.h.value += x
|
310
|
-
return self.update_return()
|
311
|
-
|
312
|
-
def update_return(self) -> PyTree:
|
313
310
|
return self.g.value
|
314
311
|
|
315
312
|
|
@@ -394,7 +391,7 @@ class AMPA(Synapse):
|
|
394
391
|
beta: ArrayLike = 0.18 / u.ms,
|
395
392
|
T: ArrayLike = 0.5 * u.mM,
|
396
393
|
T_dur: ArrayLike = 0.5 * u.ms,
|
397
|
-
g_initializer: ArrayLike | Callable = init.ZeroInit(),
|
394
|
+
g_initializer: ArrayLike | Callable = init.ZeroInit(unit=u.mS),
|
398
395
|
):
|
399
396
|
super().__init__(name=name, in_size=in_size)
|
400
397
|
|
@@ -413,14 +410,12 @@ class AMPA(Synapse):
|
|
413
410
|
self.g.value = init.param(self.g_initializer, self.varshape, batch_or_mode)
|
414
411
|
self.spike_arrival_time.value = init.param(init.Constant(-1e7 * u.ms), self.varshape, batch_or_mode)
|
415
412
|
|
416
|
-
def dg(self, g, t, TT):
|
417
|
-
return self.alpha * TT * (1 - g) - self.beta * g
|
418
|
-
|
419
413
|
def update(self, pre_spike):
|
420
414
|
t = environ.get('t')
|
421
415
|
self.spike_arrival_time.value = u.math.where(pre_spike, t, self.spike_arrival_time.value)
|
422
416
|
TT = ((t - self.spike_arrival_time.value) < self.T_duration) * self.T
|
423
|
-
|
417
|
+
dg = lambda g: self.alpha * TT * (1 - g) - self.beta * g
|
418
|
+
self.g.value = exp_euler_step(dg, self.g.value)
|
424
419
|
return self.update_return()
|
425
420
|
|
426
421
|
def update_return(self) -> PyTree:
|
@@ -507,7 +502,7 @@ class GABAa(AMPA):
|
|
507
502
|
beta: ArrayLike = 0.18 / u.ms,
|
508
503
|
T: ArrayLike = 1.0 * u.mM,
|
509
504
|
T_dur: ArrayLike = 1.0 * u.ms,
|
510
|
-
g_initializer: ArrayLike | Callable = init.ZeroInit(),
|
505
|
+
g_initializer: ArrayLike | Callable = init.ZeroInit(unit=u.mS),
|
511
506
|
):
|
512
507
|
super().__init__(
|
513
508
|
alpha=alpha,
|
@@ -518,3 +513,4 @@ class GABAa(AMPA):
|
|
518
513
|
in_size=in_size,
|
519
514
|
g_initializer=g_initializer
|
520
515
|
)
|
516
|
+
|