brainstate 0.1.8__py2.py3-none-any.whl → 0.1.9__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +58 -51
- brainstate/_compatible_import.py +148 -148
- brainstate/_state.py +1605 -1663
- brainstate/_state_test.py +52 -52
- brainstate/_utils.py +47 -47
- brainstate/augment/__init__.py +30 -30
- brainstate/augment/_autograd.py +778 -778
- brainstate/augment/_autograd_test.py +1289 -1289
- brainstate/augment/_eval_shape.py +99 -99
- brainstate/augment/_eval_shape_test.py +38 -38
- brainstate/augment/_mapping.py +1060 -1060
- brainstate/augment/_mapping_test.py +597 -597
- brainstate/augment/_random.py +151 -151
- brainstate/compile/__init__.py +38 -38
- brainstate/compile/_ad_checkpoint.py +204 -204
- brainstate/compile/_ad_checkpoint_test.py +49 -49
- brainstate/compile/_conditions.py +256 -256
- brainstate/compile/_conditions_test.py +220 -220
- brainstate/compile/_error_if.py +92 -92
- brainstate/compile/_error_if_test.py +52 -52
- brainstate/compile/_jit.py +346 -346
- brainstate/compile/_jit_test.py +143 -143
- brainstate/compile/_loop_collect_return.py +536 -536
- brainstate/compile/_loop_collect_return_test.py +58 -58
- brainstate/compile/_loop_no_collection.py +184 -184
- brainstate/compile/_loop_no_collection_test.py +50 -50
- brainstate/compile/_make_jaxpr.py +888 -888
- brainstate/compile/_make_jaxpr_test.py +156 -156
- brainstate/compile/_progress_bar.py +202 -202
- brainstate/compile/_unvmap.py +159 -159
- brainstate/compile/_util.py +147 -147
- brainstate/environ.py +563 -563
- brainstate/environ_test.py +62 -62
- brainstate/functional/__init__.py +27 -26
- brainstate/graph/__init__.py +29 -29
- brainstate/graph/_graph_node.py +244 -244
- brainstate/graph/_graph_node_test.py +73 -73
- brainstate/graph/_graph_operation.py +1738 -1738
- brainstate/graph/_graph_operation_test.py +563 -563
- brainstate/init/__init__.py +26 -26
- brainstate/init/_base.py +52 -52
- brainstate/init/_generic.py +244 -244
- brainstate/init/_random_inits.py +553 -553
- brainstate/init/_random_inits_test.py +149 -149
- brainstate/init/_regular_inits.py +105 -105
- brainstate/init/_regular_inits_test.py +50 -50
- brainstate/mixin.py +365 -363
- brainstate/mixin_test.py +77 -73
- brainstate/nn/__init__.py +135 -131
- brainstate/{functional → nn}/_activations.py +808 -813
- brainstate/{functional → nn}/_activations_test.py +331 -331
- brainstate/nn/_collective_ops.py +514 -514
- brainstate/nn/_collective_ops_test.py +43 -43
- brainstate/nn/_common.py +178 -178
- brainstate/nn/_conv.py +501 -501
- brainstate/nn/_conv_test.py +238 -238
- brainstate/nn/_delay.py +509 -502
- brainstate/nn/_delay_test.py +238 -184
- brainstate/nn/_dropout.py +426 -426
- brainstate/nn/_dropout_test.py +100 -100
- brainstate/nn/_dynamics.py +1343 -1343
- brainstate/nn/_dynamics_test.py +78 -78
- brainstate/nn/_elementwise.py +1119 -1119
- brainstate/nn/_elementwise_test.py +169 -169
- brainstate/nn/_embedding.py +58 -58
- brainstate/nn/_exp_euler.py +92 -92
- brainstate/nn/_exp_euler_test.py +35 -35
- brainstate/nn/_fixedprob.py +239 -239
- brainstate/nn/_fixedprob_test.py +114 -114
- brainstate/nn/_inputs.py +608 -608
- brainstate/nn/_linear.py +424 -424
- brainstate/nn/_linear_mv.py +83 -83
- brainstate/nn/_linear_mv_test.py +120 -120
- brainstate/nn/_linear_test.py +107 -107
- brainstate/nn/_ltp.py +28 -28
- brainstate/nn/_module.py +377 -377
- brainstate/nn/_module_test.py +40 -40
- brainstate/nn/_neuron.py +705 -705
- brainstate/nn/_neuron_test.py +161 -161
- brainstate/nn/_normalizations.py +975 -918
- brainstate/nn/_normalizations_test.py +73 -73
- brainstate/{functional → nn}/_others.py +46 -46
- brainstate/nn/_poolings.py +1177 -1177
- brainstate/nn/_poolings_test.py +217 -217
- brainstate/nn/_projection.py +486 -486
- brainstate/nn/_rate_rnns.py +554 -554
- brainstate/nn/_rate_rnns_test.py +63 -63
- brainstate/nn/_readout.py +209 -209
- brainstate/nn/_readout_test.py +53 -53
- brainstate/nn/_stp.py +236 -236
- brainstate/nn/_synapse.py +505 -505
- brainstate/nn/_synapse_test.py +131 -131
- brainstate/nn/_synaptic_projection.py +423 -423
- brainstate/nn/_synouts.py +162 -162
- brainstate/nn/_synouts_test.py +57 -57
- brainstate/nn/_utils.py +89 -89
- brainstate/nn/metrics.py +388 -388
- brainstate/optim/__init__.py +38 -38
- brainstate/optim/_base.py +64 -64
- brainstate/optim/_lr_scheduler.py +448 -448
- brainstate/optim/_lr_scheduler_test.py +50 -50
- brainstate/optim/_optax_optimizer.py +152 -152
- brainstate/optim/_optax_optimizer_test.py +53 -53
- brainstate/optim/_sgd_optimizer.py +1104 -1104
- brainstate/random/__init__.py +24 -24
- brainstate/random/_rand_funs.py +3616 -3616
- brainstate/random/_rand_funs_test.py +567 -567
- brainstate/random/_rand_seed.py +210 -210
- brainstate/random/_rand_seed_test.py +48 -48
- brainstate/random/_rand_state.py +1409 -1409
- brainstate/random/_random_for_unit.py +52 -52
- brainstate/surrogate.py +1957 -1957
- brainstate/transform.py +23 -23
- brainstate/typing.py +304 -304
- brainstate/util/__init__.py +50 -50
- brainstate/util/caller.py +98 -98
- brainstate/util/error.py +55 -55
- brainstate/util/filter.py +469 -469
- brainstate/util/others.py +540 -540
- brainstate/util/pretty_pytree.py +945 -945
- brainstate/util/pretty_pytree_test.py +159 -159
- brainstate/util/pretty_repr.py +328 -328
- brainstate/util/pretty_table.py +2954 -2954
- brainstate/util/scaling.py +258 -258
- brainstate/util/struct.py +523 -523
- {brainstate-0.1.8.dist-info → brainstate-0.1.9.dist-info}/METADATA +91 -99
- brainstate-0.1.9.dist-info/RECORD +130 -0
- {brainstate-0.1.8.dist-info → brainstate-0.1.9.dist-info}/WHEEL +1 -1
- {brainstate-0.1.8.dist-info → brainstate-0.1.9.dist-info/licenses}/LICENSE +202 -202
- brainstate/functional/_normalization.py +0 -81
- brainstate/functional/_spikes.py +0 -204
- brainstate-0.1.8.dist-info/RECORD +0 -132
- {brainstate-0.1.8.dist-info → brainstate-0.1.9.dist-info}/top_level.txt +0 -0
brainstate/nn/_inputs.py
CHANGED
@@ -1,608 +1,608 @@
|
|
1
|
-
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
15
|
-
|
16
|
-
from typing import Union, Optional, Sequence, Callable
|
17
|
-
|
18
|
-
import brainunit as u
|
19
|
-
import jax
|
20
|
-
import numpy as np
|
21
|
-
|
22
|
-
from brainstate import environ, init, random
|
23
|
-
from brainstate._state import ShortTermState, State, maybe_state
|
24
|
-
from brainstate.compile import while_loop
|
25
|
-
from brainstate.typing import ArrayLike, Size, DTypeLike
|
26
|
-
from ._dynamics import Dynamics, Prefetch
|
27
|
-
from ._module import Module
|
28
|
-
|
29
|
-
# Public API exported by ``from ... import *``.
__all__ = ['SpikeTime', 'PoissonSpike', 'PoissonEncoder', 'PoissonInput', 'poisson_input']
|
36
|
-
|
37
|
-
|
38
|
-
class SpikeTime(Dynamics):
    """Input neuron group that emits spikes at user-specified times.

    Entry ``k`` of ``indices``/``times`` describes one spike event: neuron
    ``indices[k]`` fires at time ``times[k]``. On every :meth:`update` call, all
    not-yet-emitted events whose time is ``<=`` the current time
    (``environ.get('t')``) are written into the returned spike array.

    >>> # Get 2 neurons; neuron 0 fires at 10 ms and neuron 1 fires at 20 ms.
    >>> SpikeTime(2, indices=[0, 1], times=[10, 20])
    >>> # or
    >>> # Get 2 neurons, the neuron 0 fires spikes at 10 ms and 20 ms.
    >>> SpikeTime(2, times=[10, 20], indices=[0, 0])
    >>> # or
    >>> # Get 2 neurons, neuron 0 fires at 10 ms and 30 ms, neuron 1 fires at 20 ms.
    >>> SpikeTime(2, times=[10, 20, 30], indices=[0, 1, 0])
    >>> # or
    >>> # Get 2 neurons; at 10 ms, neuron 0 fires; at 20 ms, neuron 0 and 1 fire;
    >>> # at 30 ms, neuron 1 fires.
    >>> SpikeTime(2, times=[10, 20, 20, 30], indices=[0, 0, 1, 1])

    Parameters
    ----------
    in_size : int, tuple, list
        The neuron group geometry.
    indices : list, tuple, ArrayType
        The neuron index of each spike event.
    times : list, tuple, ArrayType
        The time of each spike event. Must have the same length as ``indices``.
    spk_type : DTypeLike, default=bool
        Data type of the spike array returned by :meth:`update`.
    name : str, optional
        The name of the dynamic system.
    need_sort : bool, default=True
        Sort the events by time. :meth:`update` scans events in order, so pass
        ``False`` only when ``times`` is already sorted in ascending order.

    Raises
    ------
    ValueError
        If ``indices`` and ``times`` have different lengths.
    """

    def __init__(
        self,
        in_size: Size,
        indices: Union[Sequence, ArrayLike],
        times: Union[Sequence, ArrayLike],
        spk_type: DTypeLike = bool,
        name: Optional[str] = None,
        need_sort: bool = True,
    ):
        super().__init__(in_size=in_size, name=name)

        # parameters
        if len(indices) != len(times):
            raise ValueError(f'The length of "indices" and "times" must be the same. '
                             f'However, we got {len(indices)} != {len(times)}.')
        self.num_times = len(times)
        self.spk_type = spk_type

        # data about times and indices
        self.times = u.math.asarray(times)
        self.indices = u.math.asarray(indices, dtype=environ.ditype())
        if need_sort:
            # ``update`` walks the events sequentially, so they must be time-ordered.
            sort_idx = u.math.argsort(self.times)
            self.indices = self.indices[sort_idx]
            self.times = self.times[sort_idx]

    def init_state(self, *args, **kwargs):
        # Position of the next event that has not been emitted yet. It must start
        # at the first event; starting at -1 would make the first check compare
        # against ``times[-1]`` (the last event) and stall until then.
        self.i = ShortTermState(0)

    def reset_state(self, batch_size=None, **kwargs):
        self.i.value = 0

    def update(self):
        t = environ.get('t')
        spike = u.math.zeros(self.varshape, dtype=self.spk_type)

        if self.num_times == 0:
            # No scheduled events; also avoids indexing an empty ``times`` array.
            return spike

        def _cond_fun(spikes):
            i = self.i.value
            # Keep emitting while unread events remain and the next one is due.
            return u.math.logical_and(i < self.num_times, t >= self.times[i])

        def _body_fun(spikes):
            i = self.i.value
            spikes = spikes.at[..., self.indices[i]].set(True)
            self.i.value += 1
            return spikes

        spike = while_loop(_cond_fun, _body_fun, spike)
        return spike
|
114
|
-
|
115
|
-
|
116
|
-
class PoissonSpike(Dynamics):
    """
    Poisson Neuron Group.

    Each neuron independently spikes in a step with probability
    ``freqs * dt``, where ``dt`` is read from ``environ.get_dt()``. The rates
    are fixed when the group is constructed.

    Parameters
    ----------
    in_size : Size
        Geometry of the neuron group.
    freqs : ArrayLike or Callable
        Firing rates, expanded to shape ``varshape`` via ``init.param``
        (``None`` is rejected).
    spk_type : DTypeLike, default=bool
        Data type of the spike array returned by :meth:`update`.
    name : str, optional
        Name of the module.
    """

    def __init__(self, in_size: Size, freqs: Union[ArrayLike, Callable],
                 spk_type: DTypeLike = bool, name: Optional[str] = None):
        super().__init__(in_size=in_size, name=name)
        self.spk_type = spk_type
        # parameters
        self.freqs = init.param(freqs, self.varshape, allow_none=False)

    def update(self):
        uniform = random.rand(*self.varshape)
        fired = uniform <= (self.freqs * environ.get_dt())
        return u.math.asarray(fired, dtype=self.spk_type)
|
139
|
-
|
140
|
-
|
141
|
-
class PoissonEncoder(Dynamics):
    r"""Encode firing rates supplied at each step into Poisson-like spike trains.

    Unlike :class:`PoissonSpike`, whose rates are fixed at construction, this
    encoder receives the rates as the argument of :meth:`update`. They may
    therefore change from one call to the next.

    On every call, each output element independently spikes with probability

    $$
    P(\text{spike}) = \text{rate} \cdot \text{dt}
    $$

    where ``dt`` comes from ``environ.get_dt()``. One uniform random number in
    $[0, 1)$ is drawn per element. A spike is emitted when that number is less
    than or equal to the probability.

    Parameters
    ----------
    in_size : Size
        Shape of the encoder. The spike array returned by :meth:`update` has
        shape ``varshape``.
    spk_type : DTypeLike, default=bool
        Data type of the generated spikes.
    name : str, optional
        Name of the encoder module.

    Examples
    --------
    >>> import brainstate as bs
    >>> import brainunit as u
    >>> import numpy as np
    >>>
    >>> # Encoder for 10 neurons
    >>> encoder = bs.nn.PoissonEncoder(10)
    >>> rates = np.array([10, 20, 30, 40, 50, 60, 70, 80, 90, 100]) * u.Hz
    >>> spikes = encoder.update(rates)
    >>>
    >>> # Rates derived from an analog signal
    >>> analog_values = np.random.rand(10) * 100
    >>> spike_train = encoder.update(analog_values * u.Hz)

    Notes
    -----
    - Draws are independent across elements and across calls. The encoder keeps
      no memory of previous spikes.
    - Each step is a Bernoulli trial. For ``rate * dt << 1`` the spike train
      approximates a Poisson process: counts in a window ``T`` are approximately
      Poisson with mean ``rate * T``.
    - The probability expression is meaningful only while ``rate * dt <= 1``.
    """

    def __init__(self, in_size: Size, spk_type: DTypeLike = bool, name: Optional[str] = None):
        super().__init__(in_size=in_size, name=name)
        self.spk_type = spk_type

    def update(self, freqs: ArrayLike):
        uniform = random.rand(*self.varshape)
        fired = uniform <= (freqs * environ.get_dt())
        return u.math.asarray(fired, dtype=self.spk_type)
|
219
|
-
|
220
|
-
|
221
|
-
class PoissonInput(Module):
    r"""Poisson Input to the given state variable.

    This class adds independent Poisson-distributed spiking input to a target
    state variable. For large numbers of inputs, it is computationally more
    efficient than creating separate Poisson spike generators.

    Synaptic events are generated randomly at simulation time and are never
    preloaded or stored. All inputs target the same variable with the same
    frequency and synaptic weight.

    Each source spikes in a step with probability

    $$
    p = \text{freq} \cdot \text{dt}
    $$

    The summed input of ``num_input`` sources is drawn in one of two ways:

    1. If $\text{num\_input} \cdot p > 5$ and $\text{num\_input} \cdot (1-p) > 5$,
       a normal approximation is used:
       $$
       \text{inputs} \sim \mathcal{N}(\mu, \sigma^2)
       $$
       where $\mu = \text{num\_input} \cdot p$ and
       $\sigma^2 = \text{num\_input} \cdot p \cdot (1-p)$.

    2. Otherwise, exact binomial sampling is used:
       $$
       \text{inputs} \sim \text{Binomial}(\text{num\_input}, p)
       $$

    Parameters
    ----------
    target : Prefetch or State
        The variable that receives the input. It can be given in two forms:

        - a :class:`Prefetch` whose ``module``/``item`` attributes identify the
          state; the state is looked up on every :meth:`update` call;
        - a :py:class:`brainstate.State`, which is used directly.
    indices : np.ndarray, jax.Array or None
        Indices of the target that receive input. If None, input is applied to
        the entire target.
    num_input : int
        The number of independent Poisson input sources.
    freq : brainunit.Quantity
        The firing frequency of each input source (e.g. ``50 * u.Hz``).
    weight : ndarray, float, or brainunit.Quantity
        The synaptic weight of each input spike.
    name : Optional[str], optional
        The name of this module.

    Examples
    --------
    >>> import brainstate as bs
    >>> import brainunit as u
    >>> import numpy as np
    >>>
    >>> neuron = bs.nn.LIF(100)
    >>> neuron.init_state(batch_size=1)
    >>>
    >>> # Poisson input to all neurons
    >>> poisson_in = bs.nn.PoissonInput(
    ...     target=neuron.V,
    ...     indices=None,
    ...     num_input=200,
    ...     freq=50 * u.Hz,
    ...     weight=0.1 * u.mV
    ... )
    >>>
    >>> # Poisson input only to specific neurons
    >>> specific_input = bs.nn.PoissonInput(
    ...     target=neuron.V,
    ...     indices=np.array([0, 10, 20, 30]),
    ...     num_input=50,
    ...     freq=100 * u.Hz,
    ...     weight=0.2 * u.mV
    ... )
    >>>
    >>> for t in range(100):
    ...     poisson_in.update()
    ...     specific_input.update()
    ...     neuron.update()

    Notes
    -----
    - Inputs are statistically independent between update steps and across
      target elements.
    - For very sparse connectivity patterns, consider individual PoissonSpike
      neurons with explicit connectivity instead.
    - :meth:`update` delegates to :func:`poisson_input`, which modifies the
      target state in place.
    """

    def __init__(
        self,
        target: Union[Prefetch, State],
        indices: Optional[Union[np.ndarray, jax.Array]],
        num_input: int,
        freq: u.Quantity[u.Hz],
        weight: Union[jax.typing.ArrayLike, u.Quantity],
        name: Optional[str] = None,
    ):
        super().__init__(name=name)

        self.target = target
        self.indices = indices
        self.num_input = num_input
        self.freq = freq
        self.weight = weight

    def update(self):
        if isinstance(self.target, State):
            # A state passed directly (as in the documented examples).
            target_state = self.target
        else:
            # A ``Prefetch``: resolve the state on its owning module each step.
            target_state = getattr(self.target.module, self.target.item)

        # generate Poisson input and add it to the target in place
        poisson_input(
            self.freq,
            self.num_input,
            self.weight,
            target_state,
            self.indices,
        )
|
345
|
-
|
346
|
-
|
347
|
-
def poisson_input(
    freq: u.Quantity[u.Hz],
    num_input: int,
    weight: Union[jax.typing.ArrayLike, u.Quantity],
    target: State,
    indices: Optional[Union[np.ndarray, jax.Array]] = None,
    refractory: Optional[Union[jax.Array]] = None,
):
    r"""Add Poisson-distributed input spikes to a target state variable, in place.

    ``num_input`` independent sources, each firing at ``freq``, are summed and
    added (times ``weight``) to ``target.value``. If ``indices`` is given, only
    those entries are updated; otherwise the whole target is updated.

    Each source spikes in a step with probability

    $$
    p = \text{freq} \cdot \text{dt}
    $$

    Define

    $$
    a = \text{num\_input} \cdot p, \qquad b = \text{num\_input} \cdot (1 - p)
    $$

    When $a > 5$ and $b > 5$, the binomial count is approximated by a normal
    distribution with matching mean and variance:

    $$
    \text{inputs} \sim \mathcal{N}(\mu = a,\ \sigma^2 = b \cdot p)
    $$

    Otherwise, it is sampled exactly:

    $$
    \text{inputs} \sim \text{Binomial}(\text{num\_input}, p)
    $$

    Parameters
    ----------
    freq : u.Quantity[u.Hz]
        The frequency of each Poisson source. A State is unwrapped via ``maybe_state``.
    num_input : int
        The number of independent Poisson sources.
    weight : ArrayLike or u.Quantity
        The synaptic weight applied to each spike. A State is unwrapped via ``maybe_state``.
    target : State
        The state variable that receives the input; its ``value`` is overwritten.
    indices : Optional[Union[np.ndarray, jax.Array]], optional
        Indices of the target that receive input. If None, the input is applied
        to the entire target.
    refractory : Optional[jax.Array], optional
        Boolean mask. Where True, the target keeps its previous value. Must
        broadcast against the target.

    Raises
    ------
    AssertionError
        If ``target`` is not a :class:`State`.

    Examples
    --------
    >>> import brainstate as bs
    >>> import brainunit as u
    >>> import numpy as np
    >>>
    >>> V = bs.HiddenState(np.zeros(100) * u.mV)
    >>>
    >>> # Poisson input to all neurons at 50 Hz
    >>> bs.nn.poisson_input(freq=50 * u.Hz, num_input=200, weight=0.1 * u.mV, target=V)
    >>>
    >>> # Only a subset of neurons
    >>> bs.nn.poisson_input(freq=100 * u.Hz, num_input=50, weight=0.2 * u.mV,
    ...                     target=V, indices=np.array([0, 10, 20, 30]))
    >>>
    >>> # With a refractory mask
    >>> refractory = np.zeros(100, dtype=bool)
    >>> refractory[40:60] = True
    >>> bs.nn.poisson_input(freq=75 * u.Hz, num_input=100, weight=0.15 * u.mV,
    ...                     target=V, refractory=refractory)

    Notes
    -----
    - Both samplers are evaluated and the result is selected element-wise.
      This works because the switching condition may be a traced value.
    - The normal approximation produces real-valued (possibly negative) counts.
    """
    freq = maybe_state(freq)
    weight = maybe_state(weight)

    assert isinstance(target, State), 'The target must be a State.'
    p = freq * environ.get_dt()
    p = p.to_decimal() if isinstance(p, u.Quantity) else p
    a = num_input * p
    b = num_input * (1 - p)
    # ``random.normal`` takes a standard deviation. The variance of
    # Binomial(num_input, p) is num_input * p * (1 - p) == b * p.
    std = u.math.sqrt(b * p)
    tar_val = target.value
    cond = u.math.logical_and(a > 5, b > 5)

    def _input_shape(tar):
        # Shape of the part of ``tar`` that receives input.
        return tar.shape if indices is None else tar[indices].shape

    # generate Poisson input (normal approximation and exact binomial)
    branch1 = jax.tree.map(
        lambda tar: random.normal(a, std, _input_shape(tar), dtype=tar.dtype),
        tar_val,
        is_leaf=u.math.is_quantity,
    )
    branch2 = jax.tree.map(
        lambda tar: random.binomial(
            num_input, p, _input_shape(tar), check_valid=False, dtype=tar.dtype
        ),
        tar_val,
        is_leaf=u.math.is_quantity,
    )
    # ``cond`` may be traced, so choose with ``where`` instead of a Python ``if``.
    inp = jax.tree.map(
        lambda b1, b2: u.math.where(cond, b1, b2),
        branch1,
        branch2,
        is_leaf=u.math.is_quantity,
    )

    # update target variable
    if indices is None:
        data = jax.tree.map(
            lambda tar, x: tar + x * weight,
            tar_val,
            inp,
            is_leaf=u.math.is_quantity,
        )
    else:
        data = jax.tree.map(
            lambda x, tar: tar.at[indices].add(x * weight),
            inp,
            tar_val,
            is_leaf=u.math.is_quantity,
        )

    if refractory is not None:
        # Refractory elements keep their previous value.
        target.value = jax.tree.map(
            lambda x, tar: u.math.where(refractory, tar, x),
            data,
            tar_val,
            is_leaf=u.math.is_quantity,
        )
    else:
        target.value = data
|
1
|
+
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
from typing import Union, Optional, Sequence, Callable
|
17
|
+
|
18
|
+
import brainunit as u
|
19
|
+
import jax
|
20
|
+
import numpy as np
|
21
|
+
|
22
|
+
from brainstate import environ, init, random
|
23
|
+
from brainstate._state import ShortTermState, State, maybe_state
|
24
|
+
from brainstate.compile import while_loop
|
25
|
+
from brainstate.typing import ArrayLike, Size, DTypeLike
|
26
|
+
from ._dynamics import Dynamics, Prefetch
|
27
|
+
from ._module import Module
|
28
|
+
|
29
|
+
# Public API exported by ``from ... import *``.
__all__ = ['SpikeTime', 'PoissonSpike', 'PoissonEncoder', 'PoissonInput', 'poisson_input']
|
36
|
+
|
37
|
+
|
38
|
+
class SpikeTime(Dynamics):
|
39
|
+
"""The input neuron group characterized by spikes emitting at given times.
|
40
|
+
|
41
|
+
>>> # Get 2 neurons, firing spikes at 10 ms and 20 ms.
|
42
|
+
>>> SpikeTime(2, times=[10, 20])
|
43
|
+
>>> # or
|
44
|
+
>>> # Get 2 neurons, the neuron 0 fires spikes at 10 ms and 20 ms.
|
45
|
+
>>> SpikeTime(2, times=[10, 20], indices=[0, 0])
|
46
|
+
>>> # or
|
47
|
+
>>> # Get 2 neurons, neuron 0 fires at 10 ms and 30 ms, neuron 1 fires at 20 ms.
|
48
|
+
>>> SpikeTime(2, times=[10, 20, 30], indices=[0, 1, 0])
|
49
|
+
>>> # or
|
50
|
+
>>> # Get 2 neurons; at 10 ms, neuron 0 fires; at 20 ms, neuron 0 and 1 fire;
|
51
|
+
>>> # at 30 ms, neuron 1 fires.
|
52
|
+
>>> SpikeTime(2, times=[10, 20, 20, 30], indices=[0, 0, 1, 1])
|
53
|
+
|
54
|
+
Parameters
|
55
|
+
----------
|
56
|
+
in_size : int, tuple, list
|
57
|
+
The neuron group geometry.
|
58
|
+
indices : list, tuple, ArrayType
|
59
|
+
The neuron indices at each time point to emit spikes.
|
60
|
+
times : list, tuple, ArrayType
|
61
|
+
The time points which generate the spikes.
|
62
|
+
name : str, optional
|
63
|
+
The name of the dynamic system.
|
64
|
+
"""
|
65
|
+
|
66
|
+
def __init__(
|
67
|
+
self,
|
68
|
+
in_size: Size,
|
69
|
+
indices: Union[Sequence, ArrayLike],
|
70
|
+
times: Union[Sequence, ArrayLike],
|
71
|
+
spk_type: DTypeLike = bool,
|
72
|
+
name: Optional[str] = None,
|
73
|
+
need_sort: bool = True,
|
74
|
+
):
|
75
|
+
super().__init__(in_size=in_size, name=name)
|
76
|
+
|
77
|
+
# parameters
|
78
|
+
if len(indices) != len(times):
|
79
|
+
raise ValueError(f'The length of "indices" and "times" must be the same. '
|
80
|
+
f'However, we got {len(indices)} != {len(times)}.')
|
81
|
+
self.num_times = len(times)
|
82
|
+
self.spk_type = spk_type
|
83
|
+
|
84
|
+
# data about times and indices
|
85
|
+
self.times = u.math.asarray(times)
|
86
|
+
self.indices = u.math.asarray(indices, dtype=environ.ditype())
|
87
|
+
if need_sort:
|
88
|
+
sort_idx = u.math.argsort(self.times)
|
89
|
+
self.indices = self.indices[sort_idx]
|
90
|
+
self.times = self.times[sort_idx]
|
91
|
+
|
92
|
+
def init_state(self, *args, **kwargs):
    """Create the event pointer ``self.i``.

    ``self.i`` is the position (in the time-sorted schedule) of the next event
    that has not been emitted yet. ``update`` compares the current time with
    ``self.times[self.i]``, so the pointer must start at 0. A negative start
    would index the *last* event, suppressing every spike until the final
    event time and then emitting that last event twice.
    """
    self.i = ShortTermState(0)

def reset_state(self, batch_size=None, **kwargs):
    """Rewind the event pointer so the schedule replays from its first event."""
    self.i.value = 0
|
97
|
+
|
98
|
+
def update(self):
    """Emit every scheduled spike whose time has been reached.

    Starting from event ``self.i``, events are consumed in stored order while
    the current time ``t`` (read from ``environ``) is at or past the event's
    time. Each consumed event sets its neuron's entry to True, and ``self.i``
    is advanced past it.

    Returns
    -------
    Spike array of shape ``varshape`` and dtype ``spk_type``.
    """
    current_t = environ.get('t')

    def _has_due_event(spk):
        idx = self.i.value
        in_range = idx < self.num_times
        is_due = current_t >= self.times[idx]
        return u.math.logical_and(in_range, is_due)

    def _emit_event(spk):
        idx = self.i.value
        spk = spk.at[..., self.indices[idx]].set(True)
        # advance the pointer past the event just emitted
        self.i.value += 1
        return spk

    empty_spikes = u.math.zeros(self.varshape, dtype=self.spk_type)
    return while_loop(_has_due_event, _emit_event, empty_spikes)
|
114
|
+
|
115
|
+
|
116
|
+
class PoissonSpike(Dynamics):
    r"""Poisson spike generator whose firing rates are fixed at construction.

    At every call to :meth:`update`, each neuron independently emits a spike
    with probability

    $$
    P(\text{spike}) = \text{freqs} \cdot \text{dt}
    $$

    where ``dt`` is read from ``environ.get_dt()``.

    Parameters
    ----------
    in_size : Size
        Geometry of the neuron group.
    freqs : ArrayLike or Callable
        Firing rate of each neuron. Materialized with ``init.param`` against
        the group's ``varshape``; ``None`` is not allowed.
    spk_type : DTypeLike, default=bool
        Data type of the emitted spike array.
    name : str, optional
        Name of the module.
    """

    def __init__(
        self,
        in_size: Size,
        freqs: Union[ArrayLike, Callable],
        spk_type: DTypeLike = bool,
        name: Optional[str] = None,
    ):
        super().__init__(in_size=in_size, name=name)

        self.spk_type = spk_type

        # parameters
        self.freqs = init.param(freqs, self.varshape, allow_none=False)

    def update(self):
        """Draw one time step of Poisson spikes.

        Returns
        -------
        Spike array of dtype ``spk_type``.
        """
        # Uniform samples lie in [0, 1). A strict comparison gives
        # P(spike) == freqs * dt exactly and guarantees that a neuron with zero
        # rate never fires (``<=`` fired whenever the sample was exactly 0.0).
        spikes = random.rand(*self.varshape) < (self.freqs * environ.get_dt())
        return u.math.asarray(spikes, dtype=self.spk_type)
|
139
|
+
|
140
|
+
|
141
|
+
class PoissonEncoder(Dynamics):
    r"""Poisson spike encoder for converting firing rates to spike trains.

    This class implements a Poisson process to generate spikes based on provided
    firing rates. Unlike the PoissonSpike class, this encoder accepts firing rates
    as input during the update step rather than having them fixed at initialization.

    The spike generation follows a Poisson process where the probability of a spike
    in each time step is proportional to the firing rate and the simulation time step:

    $$
    P(\text{spike}) = \text{rate} \cdot \text{dt}
    $$

    For each neuron and time step, the encoder draws a random number from a uniform
    distribution on [0, 1) and generates a spike if the number is strictly less than
    the spiking probability. The strict comparison guarantees that a neuron whose
    rate is zero never fires.

    Parameters
    ----------
    in_size : Size
        Size of the input to the encoder, defining the shape of the output spike train.
    spk_type : DTypeLike, default=bool
        Data type for the generated spikes. Typically boolean for binary spikes.
    name : str, optional
        Name of the encoder module.

    Examples
    --------
    >>> import brainstate as bs
    >>> import brainunit as u
    >>> import numpy as np
    >>>
    >>> # The encoder reads the simulation time step from the environment
    >>> bs.environ.set(dt=0.1 * u.ms)
    >>>
    >>> # Create a Poisson encoder for 10 neurons
    >>> encoder = bs.nn.PoissonEncoder(10)
    >>>
    >>> # Generate spikes with varying firing rates
    >>> rates = np.array([10, 20, 30, 40, 50, 60, 70, 80, 90, 100]) * u.Hz
    >>> spikes = encoder.update(rates)
    >>>
    >>> # Use in a more complex processing pipeline
    >>> # First, generate rate-coded output from an analog signal
    >>> analog_values = np.random.rand(10) * 100  # values between 0 and 100
    >>> firing_rates = analog_values * u.Hz  # convert to firing rates
    >>> spike_train = encoder.update(firing_rates)
    >>>
    >>> # Feed the spikes into a spiking neural network
    >>> neuron_layer = bs.nn.LIF(10)
    >>> neuron_layer.init_state(batch_size=1)
    >>> output_spikes = neuron_layer.update(spike_train)

    Notes
    -----
    - This encoder is particularly useful for rate-to-spike conversion in neuromorphic
      computing applications and sensory encoding tasks.
    - The generated spike trains are a discrete-time (Bernoulli) approximation of a
      Poisson process: inter-spike intervals are geometrically distributed and
      approach an exponential distribution as dt becomes small.
    - For small time steps (dt), the number of spikes in a time window T approximately
      follows a Poisson distribution with parameter λ = rate * T.
    - If rate * dt >= 1, the neuron fires at every time step.
    - Unlike PoissonSpike which has fixed rates, this encoder allows dynamic rate changes
      with every update call, making it suitable for encoding time-varying signals.
    - The independence of spike generation between time steps results in renewal process
      statistics without memory of previous spiking history.
    """

    def __init__(
        self,
        in_size: Size,
        spk_type: DTypeLike = bool,
        name: Optional[str] = None,
    ):
        super().__init__(in_size=in_size, name=name)
        self.spk_type = spk_type

    def update(self, freqs: ArrayLike):
        """Draw one time step of spikes for the given firing rates.

        Parameters
        ----------
        freqs : ArrayLike
            Firing rates; broadcast against an array of shape ``varshape``.

        Returns
        -------
        Spike array of dtype ``spk_type``.
        """
        # Strict ``<``: uniform samples lie in [0, 1), so zero rates never spike.
        spikes = random.rand(*self.varshape) < (freqs * environ.get_dt())
        return u.math.asarray(spikes, dtype=self.spk_type)
|
219
|
+
|
220
|
+
|
221
|
+
class PoissonInput(Module):
    r"""Poisson Input to the given state variable.

    This class provides a way to add independent Poisson-distributed spiking input
    to a target state variable. For large numbers of inputs, this implementation is
    computationally more efficient than creating separate Poisson spike generators.

    The synaptic events are generated randomly during simulation runtime and are not
    preloaded or stored in memory, which improves memory efficiency for large-scale
    simulations. All inputs target the same variable with the same frequency and
    synaptic weight.

    The Poisson process generates spikes with probability based on the frequency and
    simulation time step:

    $$
    P(\text{spike}) = \text{freq} \cdot \text{dt}
    $$

    For computational efficiency, two different methods are used for spike generation:

    1. For large numbers of inputs, a normal approximation:
    $$
    \text{inputs} \sim \mathcal{N}(\mu, \sigma^2)
    $$
    where $\mu = \text{num\_input} \cdot p$ and $\sigma^2 = \text{num\_input} \cdot p \cdot (1-p)$

    2. For smaller numbers, a direct binomial sampling:
    $$
    \text{inputs} \sim \text{Binomial}(\text{num\_input}, p)
    $$

    where $p = \text{freq} \cdot \text{dt}$ in both cases.

    Parameters
    ----------
    target : Prefetch or State
        The variable that is targeted by this input. Either a
        :py:class:`brainstate.State` passed directly, or a ``Prefetch`` whose
        ``module``/``item`` attributes name the state; a ``Prefetch`` is
        resolved at every :meth:`update` call.
    indices : np.ndarray, jax.Array or None
        Indices of the target to receive input. If None, input is applied to the entire target.
    num_input : int
        The number of independent Poisson input sources.
    freq : Union[int, float]
        The firing frequency of each input source in Hz.
    weight : ndarray, float, or brainunit.Quantity
        The synaptic weight of each input spike.
    name : Optional[str], optional
        The name of this module.

    Examples
    --------
    >>> import brainstate as bs
    >>> import brainunit as u
    >>> import numpy as np
    >>>
    >>> # Create a neuron group with membrane potential
    >>> neuron = bs.nn.LIF(100)
    >>> neuron.init_state(batch_size=1)
    >>>
    >>> # Add Poisson input to all neurons
    >>> poisson_in = bs.nn.PoissonInput(
    ...     target=neuron.V,
    ...     indices=None,
    ...     num_input=200,
    ...     freq=50 * u.Hz,
    ...     weight=0.1 * u.mV
    ... )
    >>>
    >>> # Add Poisson input only to specific neurons
    >>> indices = np.array([0, 10, 20, 30])
    >>> specific_input = bs.nn.PoissonInput(
    ...     target=neuron.V,
    ...     indices=indices,
    ...     num_input=50,
    ...     freq=100 * u.Hz,
    ...     weight=0.2 * u.mV
    ... )
    >>>
    >>> # Run simulation with the inputs
    >>> for t in range(100):
    ...     poisson_in.update()
    ...     specific_input.update()
    ...     neuron.update()

    Notes
    -----
    - The Poisson inputs are statistically independent between update steps and across
      target neurons.
    - This implementation is particularly efficient for large numbers of inputs or targets.
    - For very sparse connectivity patterns, consider using individual PoissonSpike neurons
      with specific connectivity patterns instead.
    - The update method internally calls the poisson_input function which handles the
      spike generation and target state updates.
    """

    def __init__(
        self,
        target: Union[Prefetch, State],
        indices: Optional[Union[np.ndarray, jax.Array]],
        num_input: int,
        freq: u.Quantity[u.Hz],
        weight: Union[jax.typing.ArrayLike, u.Quantity],
        name: Optional[str] = None,
    ):
        super().__init__(name=name)

        self.target = target
        self.indices = indices
        self.num_input = num_input
        self.freq = freq
        self.weight = weight

    def update(self):
        """Add one time step of Poisson input to the target state."""
        if isinstance(self.target, State):
            # A State passed directly (as in the examples) is used as-is.
            target_state = self.target
        else:
            # Prefetch: look the state up on its owning module at call time.
            target_state = getattr(self.target.module, self.target.item)

        # generate Poisson input
        poisson_input(
            self.freq,
            self.num_input,
            self.weight,
            target_state,
            self.indices,
        )
|
345
|
+
|
346
|
+
|
347
|
+
def poisson_input(
    freq: u.Quantity[u.Hz],
    num_input: int,
    weight: Union[jax.typing.ArrayLike, u.Quantity],
    target: State,
    indices: Optional[Union[np.ndarray, jax.Array]] = None,
    refractory: Optional[Union[jax.Array]] = None,
):
    r"""Generates Poisson-distributed input spikes to a target state variable.

    This function simulates Poisson input to a given state, updating the target
    variable with generated spikes based on the specified frequency, number of inputs,
    and synaptic weight. The input can be applied to specific indices of the target
    or to the entire target if indices are not provided.

    The function uses two different methods to generate the Poisson-distributed input:
    1. For large numbers of inputs (a > 5 and b > 5), a normal approximation is used
    2. For smaller numbers, a direct binomial sampling approach is used

    Mathematical model for Poisson input:
    $$
    P(\text{spike}) = \text{freq} \cdot \text{dt}
    $$

    For the normal approximation (when a > 5 and b > 5), with mean $a$ and
    variance $b \cdot p$:
    $$
    \text{inputs} \sim \mathcal{N}(a, b \cdot p)
    $$
    where:
    $$
    a = \text{num\_input} \cdot p
    $$
    $$
    b = \text{num\_input} \cdot (1 - p)
    $$
    $$
    p = \text{freq} \cdot \text{dt}
    $$
    The samples are therefore drawn with standard deviation $\sqrt{b \cdot p}$.

    For direct binomial sampling (when a ≤ 5 or b ≤ 5):
    $$
    \text{inputs} \sim \text{Binomial}(\text{num\_input}, p)
    $$

    Parameters
    ----------
    freq : u.Quantity[u.Hz]
        The frequency of the Poisson input in Hertz.
    num_input : int
        The number of input channels or neurons generating the Poisson spikes.
    weight : u.Quantity
        The synaptic weight applied to each spike.
    target : State
        The target state variable to which the Poisson input is applied.
    indices : Optional[Union[np.ndarray, jax.Array]], optional
        Specific indices of the target to apply the input. If None, the input is applied
        to the entire target.
    refractory : Optional[Union[jax.Array]], optional
        A boolean array indicating which parts of the target are in a refractory state
        and should not be updated. Should be the same length as the target.

    Examples
    --------
    >>> import brainstate as bs
    >>> import brainunit as u
    >>> import numpy as np
    >>>
    >>> # Create a membrane potential state
    >>> V = bs.HiddenState(np.zeros(100) * u.mV)
    >>>
    >>> # Add Poisson input to all neurons at 50 Hz
    >>> bs.nn.poisson_input(
    ...     freq=50 * u.Hz,
    ...     num_input=200,
    ...     weight=0.1 * u.mV,
    ...     target=V
    ... )
    >>>
    >>> # Apply Poisson input only to a subset of neurons
    >>> indices = np.array([0, 10, 20, 30])
    >>> bs.nn.poisson_input(
    ...     freq=100 * u.Hz,
    ...     num_input=50,
    ...     weight=0.2 * u.mV,
    ...     target=V,
    ...     indices=indices
    ... )
    >>>
    >>> # Apply input with refractory mask
    >>> refractory = np.zeros(100, dtype=bool)
    >>> refractory[40:60] = True  # neurons 40-59 are in refractory period
    >>> bs.nn.poisson_input(
    ...     freq=75 * u.Hz,
    ...     num_input=100,
    ...     weight=0.15 * u.mV,
    ...     target=V,
    ...     refractory=refractory
    ... )

    Notes
    -----
    - The function automatically switches between normal approximation and binomial
      sampling based on the input parameters to optimize computation efficiency.
    - For large numbers of inputs, the normal approximation provides significant
      performance improvements.
    - The weight parameter is applied uniformly to all generated spikes.
    - When refractory is provided, the corresponding target elements are not updated.
    """
    freq = maybe_state(freq)
    weight = maybe_state(weight)

    assert isinstance(target, State), 'The target must be a State.'
    p = freq * environ.get_dt()
    p = p.to_decimal() if isinstance(p, u.Quantity) else p
    a = num_input * p
    b = num_input * (1 - p)
    tar_val = target.value
    cond = u.math.logical_and(a > 5, b > 5)

    # `random.normal(loc, scale, ...)` takes a standard deviation, while
    # b * p = num_input * p * (1 - p) is the binomial *variance*.
    normal_std = u.math.sqrt(b * p)

    def _sample_shape(tar):
        # Shape of the slice of `tar` that receives input.
        return tar.shape if indices is None else tar[indices].shape

    # generate Poisson input: both approximations are sampled for every leaf
    # and the appropriate one is selected element-wise with `where`.
    branch1 = jax.tree.map(
        lambda tar: random.normal(
            a,
            normal_std,
            _sample_shape(tar),
            dtype=tar.dtype
        ),
        tar_val,
        is_leaf=u.math.is_quantity
    )
    branch2 = jax.tree.map(
        lambda tar: random.binomial(
            num_input,
            p,
            _sample_shape(tar),
            check_valid=False,
            dtype=tar.dtype
        ),
        tar_val,
        is_leaf=u.math.is_quantity,
    )
    inp = jax.tree.map(
        lambda b1, b2: u.math.where(cond, b1, b2),
        branch1,
        branch2,
        is_leaf=u.math.is_quantity,
    )

    # update target variable
    if indices is None:
        data = jax.tree.map(
            lambda tar, x: tar + x * weight,
            tar_val,
            inp,
            is_leaf=u.math.is_quantity
        )
    else:
        data = jax.tree.map(
            lambda x, tar: tar.at[indices].add(x * weight),
            inp,
            tar_val,
            is_leaf=u.math.is_quantity
        )

    if refractory is not None:
        # keep the previous value wherever the element is refractory
        target.value = jax.tree.map(
            lambda x, tar: u.math.where(refractory, tar, x),
            data,
            tar_val,
            is_leaf=u.math.is_quantity
        )
    else:
        target.value = data
|