brainstate 0.1.0.post20250420__py2.py3-none-any.whl → 0.1.0.post20250422__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/_compatible_import.py +15 -0
- brainstate/_state.py +5 -4
- brainstate/_state_test.py +2 -1
- brainstate/augment/_autograd_test.py +3 -2
- brainstate/augment/_eval_shape.py +2 -1
- brainstate/augment/_mapping.py +0 -1
- brainstate/augment/_mapping_test.py +1 -0
- brainstate/compile/_ad_checkpoint.py +2 -1
- brainstate/compile/_conditions.py +3 -3
- brainstate/compile/_conditions_test.py +2 -1
- brainstate/compile/_error_if.py +2 -1
- brainstate/compile/_error_if_test.py +2 -1
- brainstate/compile/_jit.py +3 -2
- brainstate/compile/_jit_test.py +2 -1
- brainstate/compile/_loop_collect_return.py +2 -2
- brainstate/compile/_loop_collect_return_test.py +2 -1
- brainstate/compile/_loop_no_collection.py +1 -1
- brainstate/compile/_make_jaxpr.py +2 -2
- brainstate/compile/_make_jaxpr_test.py +2 -1
- brainstate/compile/_progress_bar.py +2 -1
- brainstate/compile/_unvmap.py +1 -2
- brainstate/environ.py +4 -4
- brainstate/environ_test.py +2 -1
- brainstate/functional/_activations.py +2 -1
- brainstate/functional/_activations_test.py +1 -1
- brainstate/functional/_normalization.py +2 -1
- brainstate/functional/_others.py +2 -1
- brainstate/graph/_graph_operation.py +3 -2
- brainstate/graph/_graph_operation_test.py +4 -3
- brainstate/init/_base.py +2 -1
- brainstate/init/_generic.py +2 -1
- brainstate/nn/__init__.py +4 -0
- brainstate/nn/_collective_ops.py +1 -0
- brainstate/nn/_collective_ops_test.py +0 -4
- brainstate/nn/_common.py +0 -1
- brainstate/nn/_dyn_impl/__init__.py +0 -4
- brainstate/nn/_dyn_impl/_dynamics_neuron.py +431 -13
- brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +2 -1
- brainstate/nn/_dyn_impl/_dynamics_synapse.py +405 -103
- brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +2 -1
- brainstate/nn/_dyn_impl/_inputs.py +236 -29
- brainstate/nn/_dyn_impl/_rate_rnns.py +238 -82
- brainstate/nn/_dyn_impl/_rate_rnns_test.py +2 -1
- brainstate/nn/_dyn_impl/_readout.py +91 -8
- brainstate/nn/_dyn_impl/_readout_test.py +2 -1
- brainstate/nn/_dynamics/_dynamics_base.py +676 -96
- brainstate/nn/_dynamics/_dynamics_base_test.py +2 -1
- brainstate/nn/_dynamics/_projection_base.py +29 -30
- brainstate/nn/_dynamics/_state_delay.py +3 -3
- brainstate/nn/_dynamics/_synouts_test.py +2 -1
- brainstate/nn/_elementwise/_dropout.py +3 -2
- brainstate/nn/_elementwise/_dropout_test.py +2 -1
- brainstate/nn/_elementwise/_elementwise.py +2 -1
- brainstate/nn/{_dyn_impl/_projection_alignpost.py → _event/__init__.py} +8 -7
- brainstate/nn/_event/_fixedprob_mv.py +169 -0
- brainstate/nn/_event/_fixedprob_mv_test.py +115 -0
- brainstate/nn/_event/_linear_mv.py +85 -0
- brainstate/nn/_event/_linear_mv_test.py +121 -0
- brainstate/nn/_exp_euler.py +2 -1
- brainstate/nn/_exp_euler_test.py +2 -1
- brainstate/nn/_interaction/_conv.py +2 -1
- brainstate/nn/_interaction/_linear.py +2 -1
- brainstate/nn/_interaction/_linear_test.py +2 -1
- brainstate/nn/_interaction/_normalizations.py +3 -2
- brainstate/nn/_interaction/_poolings.py +4 -3
- brainstate/nn/_module_test.py +2 -1
- brainstate/nn/metrics.py +4 -3
- brainstate/optim/_lr_scheduler.py +2 -1
- brainstate/optim/_lr_scheduler_test.py +2 -1
- brainstate/optim/_optax_optimizer_test.py +2 -1
- brainstate/optim/_sgd_optimizer.py +3 -2
- brainstate/random/_rand_funs.py +2 -1
- brainstate/random/_rand_funs_test.py +3 -2
- brainstate/random/_rand_seed.py +3 -2
- brainstate/random/_rand_seed_test.py +2 -1
- brainstate/random/_rand_state.py +4 -3
- brainstate/surrogate.py +1 -2
- brainstate/typing.py +4 -4
- brainstate/util/_caller.py +2 -1
- brainstate/util/_others.py +4 -4
- brainstate/util/_pretty_pytree.py +1 -1
- brainstate/util/_pretty_pytree_test.py +2 -1
- brainstate/util/_pretty_table.py +43 -43
- brainstate/util/_struct.py +2 -1
- brainstate/util/filter.py +0 -1
- {brainstate-0.1.0.post20250420.dist-info → brainstate-0.1.0.post20250422.dist-info}/METADATA +3 -3
- brainstate-0.1.0.post20250422.dist-info/RECORD +133 -0
- brainstate-0.1.0.post20250420.dist-info/RECORD +0 -129
- {brainstate-0.1.0.post20250420.dist-info → brainstate-0.1.0.post20250422.dist-info}/LICENSE +0 -0
- {brainstate-0.1.0.post20250420.dist-info → brainstate-0.1.0.post20250422.dist-info}/WHEEL +0 -0
- {brainstate-0.1.0.post20250420.dist-info → brainstate-0.1.0.post20250422.dist-info}/top_level.txt +0 -0
@@ -14,10 +14,11 @@
|
|
14
14
|
# ==============================================================================
|
15
15
|
from __future__ import annotations
|
16
16
|
|
17
|
+
from typing import Union, Optional, Sequence, Callable
|
18
|
+
|
17
19
|
import brainunit as u
|
18
20
|
import jax
|
19
21
|
import numpy as np
|
20
|
-
from typing import Union, Optional, Sequence, Callable
|
21
22
|
|
22
23
|
from brainstate import environ, init, random
|
23
24
|
from brainstate._state import ShortTermState
|
@@ -140,8 +141,68 @@ class PoissonSpike(Dynamics):
|
|
140
141
|
|
141
142
|
|
142
143
|
class PoissonEncoder(Dynamics):
|
143
|
-
"""
|
144
|
-
|
144
|
+
r"""Poisson spike encoder for converting firing rates to spike trains.
|
145
|
+
|
146
|
+
This class implements a Poisson process to generate spikes based on provided
|
147
|
+
firing rates. Unlike the PoissonSpike class, this encoder accepts firing rates
|
148
|
+
as input during the update step rather than having them fixed at initialization.
|
149
|
+
|
150
|
+
The spike generation follows a Poisson process where the probability of a spike
|
151
|
+
in each time step is proportional to the firing rate and the simulation time step:
|
152
|
+
|
153
|
+
$$
|
154
|
+
P(\text{spike}) = \text{rate} \cdot \text{dt}
|
155
|
+
$$
|
156
|
+
|
157
|
+
For each neuron and time step, the encoder draws a random number from a uniform
|
158
|
+
distribution [0,1] and generates a spike if the number is less than or equal to
|
159
|
+
the spiking probability.
|
160
|
+
|
161
|
+
Parameters
|
162
|
+
----------
|
163
|
+
in_size : Size
|
164
|
+
Size of the input to the encoder, defining the shape of the output spike train.
|
165
|
+
spk_type : DTypeLike, default=bool
|
166
|
+
Data type for the generated spikes. Typically boolean for binary spikes.
|
167
|
+
name : str, optional
|
168
|
+
Name of the encoder module.
|
169
|
+
|
170
|
+
Examples
|
171
|
+
--------
|
172
|
+
>>> import brainstate as bs
|
173
|
+
>>> import brainunit as u
|
174
|
+
>>> import numpy as np
|
175
|
+
>>>
|
176
|
+
>>> # Create a Poisson encoder for 10 neurons
|
177
|
+
>>> encoder = bs.nn.PoissonEncoder(10)
|
178
|
+
>>>
|
179
|
+
>>> # Generate spikes with varying firing rates
|
180
|
+
>>> rates = np.array([10, 20, 30, 40, 50, 60, 70, 80, 90, 100]) * u.Hz
|
181
|
+
>>> spikes = encoder.update(rates)
|
182
|
+
>>>
|
183
|
+
>>> # Use in a more complex processing pipeline
|
184
|
+
>>> # First, generate rate-coded output from an analog signal
|
185
|
+
>>> analog_values = np.random.rand(10) * 100 # values between 0 and 100
|
186
|
+
>>> firing_rates = analog_values * u.Hz # convert to firing rates
|
187
|
+
>>> spike_train = encoder.update(firing_rates)
|
188
|
+
>>>
|
189
|
+
>>> # Feed the spikes into a spiking neural network
|
190
|
+
>>> neuron_layer = bs.nn.LIF(10)
|
191
|
+
>>> neuron_layer.init_state(batch_size=1)
|
192
|
+
>>> output_spikes = neuron_layer.update(spike_train)
|
193
|
+
|
194
|
+
Notes
|
195
|
+
-----
|
196
|
+
- This encoder is particularly useful for rate-to-spike conversion in neuromorphic
|
197
|
+
computing applications and sensory encoding tasks.
|
198
|
+
- The statistical properties of the generated spike trains follow a Poisson process,
|
199
|
+
where the inter-spike intervals are exponentially distributed.
|
200
|
+
- For small time steps (dt), the number of spikes in a time window T approximately
|
201
|
+
follows a Poisson distribution with parameter λ = rate * T.
|
202
|
+
- Unlike PoissonSpike which has fixed rates, this encoder allows dynamic rate changes
|
203
|
+
with every update call, making it suitable for encoding time-varying signals.
|
204
|
+
- The independence of spike generation between time steps results in renewal process
|
205
|
+
statistics without memory of previous spiking history.
|
145
206
|
"""
|
146
207
|
|
147
208
|
def __init__(
|
@@ -160,23 +221,99 @@ class PoissonEncoder(Dynamics):
|
|
160
221
|
|
161
222
|
|
162
223
|
class PoissonInput(Module):
|
163
|
-
"""
|
164
|
-
|
165
|
-
|
166
|
-
|
167
|
-
|
168
|
-
|
169
|
-
|
170
|
-
|
171
|
-
All
|
172
|
-
|
173
|
-
|
174
|
-
|
175
|
-
|
176
|
-
|
177
|
-
|
178
|
-
|
179
|
-
|
224
|
+
r"""Poisson Input to the given state variable.
|
225
|
+
|
226
|
+
This class provides a way to add independent Poisson-distributed spiking input
|
227
|
+
to a target state variable. For large numbers of inputs, this implementation is
|
228
|
+
computationally more efficient than creating separate Poisson spike generators.
|
229
|
+
|
230
|
+
The synaptic events are generated randomly during simulation runtime and are not
|
231
|
+
preloaded or stored in memory, which improves memory efficiency for large-scale
|
232
|
+
simulations. All inputs target the same variable with the same frequency and
|
233
|
+
synaptic weight.
|
234
|
+
|
235
|
+
The Poisson process generates spikes with probability based on the frequency and
|
236
|
+
simulation time step:
|
237
|
+
|
238
|
+
$$
|
239
|
+
P(\text{spike}) = \text{freq} \cdot \text{dt}
|
240
|
+
$$
|
241
|
+
|
242
|
+
For computational efficiency, two different methods are used for spike generation:
|
243
|
+
|
244
|
+
1. For large numbers of inputs, a normal approximation:
|
245
|
+
$$
|
246
|
+
\text{inputs} \sim \mathcal{N}(\mu, \sigma^2)
|
247
|
+
$$
|
248
|
+
where $\mu = \text{num\_input} \cdot p$ and $\sigma^2 = \text{num\_input} \cdot p \cdot (1-p)$
|
249
|
+
|
250
|
+
2. For smaller numbers, a direct binomial sampling:
|
251
|
+
$$
|
252
|
+
\text{inputs} \sim \text{Binomial}(\text{num\_input}, p)
|
253
|
+
$$
|
254
|
+
|
255
|
+
where $p = \text{freq} \cdot \text{dt}$ in both cases.
|
256
|
+
|
257
|
+
Parameters
|
258
|
+
----------
|
259
|
+
target : Prefetch
|
260
|
+
The variable that is targeted by this input. Should be an instance of
|
261
|
+
:py:class:`brainstate.State` that's prefetched via the target mechanism.
|
262
|
+
indices : Optional[Union[np.ndarray, jax.Array]]
|
263
|
+
Indices of the target to receive input. If None, input is applied to the entire target.
|
264
|
+
num_input : int
|
265
|
+
The number of independent Poisson input sources.
|
266
|
+
freq : u.Quantity[u.Hz]
|
267
|
+
The firing frequency of each input source in Hz.
|
268
|
+
weight : ndarray, float, or brainunit.Quantity
|
269
|
+
The synaptic weight of each input spike.
|
270
|
+
name : Optional[str], optional
|
271
|
+
The name of this module.
|
272
|
+
|
273
|
+
Examples
|
274
|
+
--------
|
275
|
+
>>> import brainstate as bs
|
276
|
+
>>> import brainunit as u
|
277
|
+
>>> import numpy as np
|
278
|
+
>>>
|
279
|
+
>>> # Create a neuron group with membrane potential
|
280
|
+
>>> neuron = bs.nn.LIF(100)
|
281
|
+
>>> neuron.init_state(batch_size=1)
|
282
|
+
>>>
|
283
|
+
>>> # Add Poisson input to all neurons
|
284
|
+
>>> poisson_in = bs.nn.PoissonInput(
|
285
|
+
... target=neuron.V,
|
286
|
+
... indices=None,
|
287
|
+
... num_input=200,
|
288
|
+
... freq=50 * u.Hz,
|
289
|
+
... weight=0.1 * u.mV
|
290
|
+
... )
|
291
|
+
>>>
|
292
|
+
>>> # Add Poisson input only to specific neurons
|
293
|
+
>>> indices = np.array([0, 10, 20, 30])
|
294
|
+
>>> specific_input = bs.nn.PoissonInput(
|
295
|
+
... target=neuron.V,
|
296
|
+
... indices=indices,
|
297
|
+
... num_input=50,
|
298
|
+
... freq=100 * u.Hz,
|
299
|
+
... weight=0.2 * u.mV
|
300
|
+
... )
|
301
|
+
>>>
|
302
|
+
>>> # Run simulation with the inputs
|
303
|
+
>>> for t in range(100):
|
304
|
+
... poisson_in.update()
|
305
|
+
... specific_input.update()
|
306
|
+
... neuron.update()
|
307
|
+
|
308
|
+
Notes
|
309
|
+
-----
|
310
|
+
- The Poisson inputs are statistically independent between update steps and across
|
311
|
+
target neurons.
|
312
|
+
- This implementation is particularly efficient for large numbers of inputs or targets.
|
313
|
+
- For very sparse connectivity patterns, consider using individual PoissonSpike neurons
|
314
|
+
with specific connectivity patterns instead.
|
315
|
+
- The update method internally calls the poisson_input function which handles the
|
316
|
+
spike generation and target state updates.
|
180
317
|
"""
|
181
318
|
|
182
319
|
def __init__(
|
@@ -184,8 +321,8 @@ class PoissonInput(Module):
|
|
184
321
|
target: Prefetch,
|
185
322
|
indices: Union[np.ndarray, jax.Array],
|
186
323
|
num_input: int,
|
187
|
-
freq:
|
188
|
-
weight: Union[
|
324
|
+
freq: u.Quantity[u.Hz],
|
325
|
+
weight: Union[jax.typing.ArrayLike, u.Quantity],
|
189
326
|
name: Optional[str] = None,
|
190
327
|
):
|
191
328
|
super().__init__(name=name)
|
@@ -212,19 +349,47 @@ class PoissonInput(Module):
|
|
212
349
|
def poisson_input(
|
213
350
|
freq: u.Quantity[u.Hz],
|
214
351
|
num_input: int,
|
215
|
-
weight: u.Quantity,
|
352
|
+
weight: Union[jax.typing.ArrayLike, u.Quantity],
|
216
353
|
target: State,
|
217
354
|
indices: Optional[Union[np.ndarray, jax.Array]] = None,
|
218
355
|
refractory: Optional[Union[jax.Array]] = None,
|
219
356
|
):
|
220
|
-
"""
|
221
|
-
Generates Poisson-distributed input spikes to a target state variable.
|
357
|
+
r"""Generates Poisson-distributed input spikes to a target state variable.
|
222
358
|
|
223
359
|
This function simulates Poisson input to a given state, updating the target
|
224
360
|
variable with generated spikes based on the specified frequency, number of inputs,
|
225
361
|
and synaptic weight. The input can be applied to specific indices of the target
|
226
362
|
or to the entire target if indices are not provided.
|
227
363
|
|
364
|
+
The function uses two different methods to generate the Poisson-distributed input:
|
365
|
+
1. For large numbers of inputs (a > 5 and b > 5, with a and b as defined below), a normal approximation is used
|
366
|
+
2. For smaller numbers, a direct binomial sampling approach is used
|
367
|
+
|
368
|
+
Mathematical model for Poisson input:
|
369
|
+
$$
|
370
|
+
P(\text{spike}) = \text{freq} \cdot \text{dt}
|
371
|
+
$$
|
372
|
+
|
373
|
+
For the normal approximation (when a > 5 and b > 5):
|
374
|
+
$$
|
375
|
+
\text{inputs} \sim \mathcal{N}(a, b \cdot p)
|
376
|
+
$$
|
377
|
+
where:
|
378
|
+
$$
|
379
|
+
a = \text{num\_input} \cdot p
|
380
|
+
$$
|
381
|
+
$$
|
382
|
+
b = \text{num\_input} \cdot (1 - p)
|
383
|
+
$$
|
384
|
+
$$
|
385
|
+
p = \text{freq} \cdot \text{dt}
|
386
|
+
$$
|
387
|
+
|
388
|
+
For direct binomial sampling (when a ≤ 5 or b ≤ 5):
|
389
|
+
$$
|
390
|
+
\text{inputs} \sim \text{Binomial}(\text{num\_input}, p)
|
391
|
+
$$
|
392
|
+
|
228
393
|
Parameters
|
229
394
|
----------
|
230
395
|
freq : u.Quantity[u.Hz]
|
@@ -242,10 +407,52 @@ def poisson_input(
|
|
242
407
|
A boolean array indicating which parts of the target are in a refractory state
|
243
408
|
and should not be updated. Should be the same length as the target.
|
244
409
|
|
245
|
-
|
246
|
-
|
247
|
-
|
248
|
-
|
410
|
+
Examples
|
411
|
+
--------
|
412
|
+
>>> import brainstate as bs
|
413
|
+
>>> import brainunit as u
|
414
|
+
>>> import numpy as np
|
415
|
+
>>>
|
416
|
+
>>> # Create a membrane potential state
|
417
|
+
>>> V = bs.HiddenState(np.zeros(100) * u.mV)
|
418
|
+
>>>
|
419
|
+
>>> # Add Poisson input to all neurons at 50 Hz
|
420
|
+
>>> bs.nn.poisson_input(
|
421
|
+
... freq=50 * u.Hz,
|
422
|
+
... num_input=200,
|
423
|
+
... weight=0.1 * u.mV,
|
424
|
+
... target=V
|
425
|
+
... )
|
426
|
+
>>>
|
427
|
+
>>> # Apply Poisson input only to a subset of neurons
|
428
|
+
>>> indices = np.array([0, 10, 20, 30])
|
429
|
+
>>> bs.nn.poisson_input(
|
430
|
+
... freq=100 * u.Hz,
|
431
|
+
... num_input=50,
|
432
|
+
... weight=0.2 * u.mV,
|
433
|
+
... target=V,
|
434
|
+
... indices=indices
|
435
|
+
... )
|
436
|
+
>>>
|
437
|
+
>>> # Apply input with refractory mask
|
438
|
+
>>> refractory = np.zeros(100, dtype=bool)
|
439
|
+
>>> refractory[40:60] = True # neurons 40-59 are in refractory period
|
440
|
+
>>> bs.nn.poisson_input(
|
441
|
+
... freq=75 * u.Hz,
|
442
|
+
... num_input=100,
|
443
|
+
... weight=0.15 * u.mV,
|
444
|
+
... target=V,
|
445
|
+
... refractory=refractory
|
446
|
+
... )
|
447
|
+
|
448
|
+
Notes
|
449
|
+
-----
|
450
|
+
- The function automatically switches between normal approximation and binomial
|
451
|
+
sampling based on the input parameters to optimize computation efficiency.
|
452
|
+
- For large numbers of inputs, the normal approximation provides significant
|
453
|
+
performance improvements.
|
454
|
+
- The weight parameter is applied uniformly to all generated spikes.
|
455
|
+
- When refractory is provided, target elements flagged as refractory are not updated.
|
249
456
|
"""
|
250
457
|
freq = maybe_state(freq)
|
251
458
|
weight = maybe_state(weight)
|