brainstate 0.1.0.post20250420__py2.py3-none-any.whl → 0.1.0.post20250422__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (91) hide show
  1. brainstate/_compatible_import.py +15 -0
  2. brainstate/_state.py +5 -4
  3. brainstate/_state_test.py +2 -1
  4. brainstate/augment/_autograd_test.py +3 -2
  5. brainstate/augment/_eval_shape.py +2 -1
  6. brainstate/augment/_mapping.py +0 -1
  7. brainstate/augment/_mapping_test.py +1 -0
  8. brainstate/compile/_ad_checkpoint.py +2 -1
  9. brainstate/compile/_conditions.py +3 -3
  10. brainstate/compile/_conditions_test.py +2 -1
  11. brainstate/compile/_error_if.py +2 -1
  12. brainstate/compile/_error_if_test.py +2 -1
  13. brainstate/compile/_jit.py +3 -2
  14. brainstate/compile/_jit_test.py +2 -1
  15. brainstate/compile/_loop_collect_return.py +2 -2
  16. brainstate/compile/_loop_collect_return_test.py +2 -1
  17. brainstate/compile/_loop_no_collection.py +1 -1
  18. brainstate/compile/_make_jaxpr.py +2 -2
  19. brainstate/compile/_make_jaxpr_test.py +2 -1
  20. brainstate/compile/_progress_bar.py +2 -1
  21. brainstate/compile/_unvmap.py +1 -2
  22. brainstate/environ.py +4 -4
  23. brainstate/environ_test.py +2 -1
  24. brainstate/functional/_activations.py +2 -1
  25. brainstate/functional/_activations_test.py +1 -1
  26. brainstate/functional/_normalization.py +2 -1
  27. brainstate/functional/_others.py +2 -1
  28. brainstate/graph/_graph_operation.py +3 -2
  29. brainstate/graph/_graph_operation_test.py +4 -3
  30. brainstate/init/_base.py +2 -1
  31. brainstate/init/_generic.py +2 -1
  32. brainstate/nn/__init__.py +4 -0
  33. brainstate/nn/_collective_ops.py +1 -0
  34. brainstate/nn/_collective_ops_test.py +0 -4
  35. brainstate/nn/_common.py +0 -1
  36. brainstate/nn/_dyn_impl/__init__.py +0 -4
  37. brainstate/nn/_dyn_impl/_dynamics_neuron.py +431 -13
  38. brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +2 -1
  39. brainstate/nn/_dyn_impl/_dynamics_synapse.py +405 -103
  40. brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +2 -1
  41. brainstate/nn/_dyn_impl/_inputs.py +236 -29
  42. brainstate/nn/_dyn_impl/_rate_rnns.py +238 -82
  43. brainstate/nn/_dyn_impl/_rate_rnns_test.py +2 -1
  44. brainstate/nn/_dyn_impl/_readout.py +91 -8
  45. brainstate/nn/_dyn_impl/_readout_test.py +2 -1
  46. brainstate/nn/_dynamics/_dynamics_base.py +676 -96
  47. brainstate/nn/_dynamics/_dynamics_base_test.py +2 -1
  48. brainstate/nn/_dynamics/_projection_base.py +29 -30
  49. brainstate/nn/_dynamics/_state_delay.py +3 -3
  50. brainstate/nn/_dynamics/_synouts_test.py +2 -1
  51. brainstate/nn/_elementwise/_dropout.py +3 -2
  52. brainstate/nn/_elementwise/_dropout_test.py +2 -1
  53. brainstate/nn/_elementwise/_elementwise.py +2 -1
  54. brainstate/nn/{_dyn_impl/_projection_alignpost.py → _event/__init__.py} +8 -7
  55. brainstate/nn/_event/_fixedprob_mv.py +169 -0
  56. brainstate/nn/_event/_fixedprob_mv_test.py +115 -0
  57. brainstate/nn/_event/_linear_mv.py +85 -0
  58. brainstate/nn/_event/_linear_mv_test.py +121 -0
  59. brainstate/nn/_exp_euler.py +2 -1
  60. brainstate/nn/_exp_euler_test.py +2 -1
  61. brainstate/nn/_interaction/_conv.py +2 -1
  62. brainstate/nn/_interaction/_linear.py +2 -1
  63. brainstate/nn/_interaction/_linear_test.py +2 -1
  64. brainstate/nn/_interaction/_normalizations.py +3 -2
  65. brainstate/nn/_interaction/_poolings.py +4 -3
  66. brainstate/nn/_module_test.py +2 -1
  67. brainstate/nn/metrics.py +4 -3
  68. brainstate/optim/_lr_scheduler.py +2 -1
  69. brainstate/optim/_lr_scheduler_test.py +2 -1
  70. brainstate/optim/_optax_optimizer_test.py +2 -1
  71. brainstate/optim/_sgd_optimizer.py +3 -2
  72. brainstate/random/_rand_funs.py +2 -1
  73. brainstate/random/_rand_funs_test.py +3 -2
  74. brainstate/random/_rand_seed.py +3 -2
  75. brainstate/random/_rand_seed_test.py +2 -1
  76. brainstate/random/_rand_state.py +4 -3
  77. brainstate/surrogate.py +1 -2
  78. brainstate/typing.py +4 -4
  79. brainstate/util/_caller.py +2 -1
  80. brainstate/util/_others.py +4 -4
  81. brainstate/util/_pretty_pytree.py +1 -1
  82. brainstate/util/_pretty_pytree_test.py +2 -1
  83. brainstate/util/_pretty_table.py +43 -43
  84. brainstate/util/_struct.py +2 -1
  85. brainstate/util/filter.py +0 -1
  86. {brainstate-0.1.0.post20250420.dist-info → brainstate-0.1.0.post20250422.dist-info}/METADATA +3 -3
  87. brainstate-0.1.0.post20250422.dist-info/RECORD +133 -0
  88. brainstate-0.1.0.post20250420.dist-info/RECORD +0 -129
  89. {brainstate-0.1.0.post20250420.dist-info → brainstate-0.1.0.post20250422.dist-info}/LICENSE +0 -0
  90. {brainstate-0.1.0.post20250420.dist-info → brainstate-0.1.0.post20250422.dist-info}/WHEEL +0 -0
  91. {brainstate-0.1.0.post20250420.dist-info → brainstate-0.1.0.post20250422.dist-info}/top_level.txt +0 -0
@@ -14,10 +14,11 @@
14
14
  # ==============================================================================
15
15
  from __future__ import annotations
16
16
 
17
+ from typing import Union, Optional, Sequence, Callable
18
+
17
19
  import brainunit as u
18
20
  import jax
19
21
  import numpy as np
20
- from typing import Union, Optional, Sequence, Callable
21
22
 
22
23
  from brainstate import environ, init, random
23
24
  from brainstate._state import ShortTermState
@@ -140,8 +141,68 @@ class PoissonSpike(Dynamics):
140
141
 
141
142
 
142
143
  class PoissonEncoder(Dynamics):
143
- """
144
- Poisson Neuron Group.
144
+ r"""Poisson spike encoder for converting firing rates to spike trains.
145
+
146
+ This class implements a Poisson process to generate spikes based on provided
147
+ firing rates. Unlike the PoissonSpike class, this encoder accepts firing rates
148
+ as input during the update step rather than having them fixed at initialization.
149
+
150
+ The spike generation follows a Poisson process where the probability of a spike
151
+ in each time step is proportional to the firing rate and the simulation time step:
152
+
153
+ $$
154
+ P(\text{spike}) = \text{rate} \cdot \text{dt}
155
+ $$
156
+
157
+ For each neuron and time step, the encoder draws a random number from a uniform
158
+ distribution [0,1] and generates a spike if the number is less than or equal to
159
+ the spiking probability.
160
+
161
+ Parameters
162
+ ----------
163
+ in_size : Size
164
+ Size of the input to the encoder, defining the shape of the output spike train.
165
+ spk_type : DTypeLike, default=bool
166
+ Data type for the generated spikes. Typically boolean for binary spikes.
167
+ name : str, optional
168
+ Name of the encoder module.
169
+
170
+ Examples
171
+ --------
172
+ >>> import brainstate as bs
173
+ >>> import brainunit as u
174
+ >>> import numpy as np
175
+ >>>
176
+ >>> # Create a Poisson encoder for 10 neurons
177
+ >>> encoder = bs.nn.PoissonEncoder(10)
178
+ >>>
179
+ >>> # Generate spikes with varying firing rates
180
+ >>> rates = np.array([10, 20, 30, 40, 50, 60, 70, 80, 90, 100]) * u.Hz
181
+ >>> spikes = encoder.update(rates)
182
+ >>>
183
+ >>> # Use in a more complex processing pipeline
184
+ >>> # First, generate rate-coded output from an analog signal
185
+ >>> analog_values = np.random.rand(10) * 100 # values between 0 and 100
186
+ >>> firing_rates = analog_values * u.Hz # convert to firing rates
187
+ >>> spike_train = encoder.update(firing_rates)
188
+ >>>
189
+ >>> # Feed the spikes into a spiking neural network
190
+ >>> neuron_layer = bs.nn.LIF(10)
191
+ >>> neuron_layer.init_state(batch_size=1)
192
+ >>> output_spikes = neuron_layer.update(spike_train)
193
+
194
+ Notes
195
+ -----
196
+ - This encoder is particularly useful for rate-to-spike conversion in neuromorphic
197
+ computing applications and sensory encoding tasks.
198
+ - The statistical properties of the generated spike trains follow a Poisson process,
199
+ where the inter-spike intervals are exponentially distributed.
200
+ - For small time steps (dt), the number of spikes in a time window T approximately
201
+ follows a Poisson distribution with parameter λ = rate * T.
202
+ - Unlike PoissonSpike which has fixed rates, this encoder allows dynamic rate changes
203
+ with every update call, making it suitable for encoding time-varying signals.
204
+ - The independence of spike generation between time steps results in renewal process
205
+ statistics without memory of previous spiking history.
145
206
  """
146
207
 
147
208
  def __init__(
@@ -160,23 +221,99 @@ class PoissonEncoder(Dynamics):
160
221
 
161
222
 
162
223
  class PoissonInput(Module):
163
- """
164
- Poisson Input to the given :py:class:`brainstate.State`.
165
-
166
- Adds independent Poisson input to a target variable. For large
167
- numbers of inputs, this is much more efficient than creating a
168
- `PoissonGroup`. The synaptic events are generated randomly during the
169
- simulation and are not preloaded and stored in memory. All the inputs must
170
- target the same variable, have the same frequency and same synaptic weight.
171
- All neurons in the target variable receive independent realizations of
172
- Poisson spike trains.
173
-
174
- Args:
175
- target: The variable that is targeted by this input. Should be an instance of :py:class:`~.Variable`.
176
- num_input: The number of inputs.
177
- freq: The frequency of each of the inputs. Must be a scalar.
178
- weight: The synaptic weight. Must be a scalar.
179
- name: The target name.
224
+ r"""Poisson Input to the given state variable.
225
+
226
+ This class provides a way to add independent Poisson-distributed spiking input
227
+ to a target state variable. For large numbers of inputs, this implementation is
228
+ computationally more efficient than creating separate Poisson spike generators.
229
+
230
+ The synaptic events are generated randomly during simulation runtime and are not
231
+ preloaded or stored in memory, which improves memory efficiency for large-scale
232
+ simulations. All inputs target the same variable with the same frequency and
233
+ synaptic weight.
234
+
235
+ The Poisson process generates spikes with probability based on the frequency and
236
+ simulation time step:
237
+
238
+ $$
239
+ P(\text{spike}) = \text{freq} \cdot \text{dt}
240
+ $$
241
+
242
+ For computational efficiency, two different methods are used for spike generation:
243
+
244
+ 1. For large numbers of inputs, a normal approximation:
245
+ $$
246
+ \text{inputs} \sim \mathcal{N}(\mu, \sigma^2)
247
+ $$
248
+ where $\mu = \text{num\_input} \cdot p$ and $\sigma^2 = \text{num\_input} \cdot p \cdot (1-p)$
249
+
250
+ 2. For smaller numbers, a direct binomial sampling:
251
+ $$
252
+ \text{inputs} \sim \text{Binomial}(\text{num\_input}, p)
253
+ $$
254
+
255
+ where $p = \text{freq} \cdot \text{dt}$ in both cases.
256
+
257
+ Parameters
258
+ ----------
259
+ target : Prefetch
260
+ The variable that is targeted by this input. Should be an instance of
261
+ :py:class:`brainstate.State` that's prefetched via the target mechanism.
262
+ indices : Union[np.ndarray, jax.Array]
263
+ Indices of the target to receive input. If None, input is applied to the entire target.
264
+ num_input : int
265
+ The number of independent Poisson input sources.
266
+ freq : u.Quantity[u.Hz]
267
+ The firing frequency of each input source, given as a quantity with frequency units (e.g. ``50 * u.Hz``).
268
+ weight : ndarray, float, or brainunit.Quantity
269
+ The synaptic weight of each input spike.
270
+ name : Optional[str], optional
271
+ The name of this module.
272
+
273
+ Examples
274
+ --------
275
+ >>> import brainstate as bs
276
+ >>> import brainunit as u
277
+ >>> import numpy as np
278
+ >>>
279
+ >>> # Create a neuron group with membrane potential
280
+ >>> neuron = bs.nn.LIF(100)
281
+ >>> neuron.init_state(batch_size=1)
282
+ >>>
283
+ >>> # Add Poisson input to all neurons
284
+ >>> poisson_in = bs.nn.PoissonInput(
285
+ ... target=neuron.V,
286
+ ... indices=None,
287
+ ... num_input=200,
288
+ ... freq=50 * u.Hz,
289
+ ... weight=0.1 * u.mV
290
+ ... )
291
+ >>>
292
+ >>> # Add Poisson input only to specific neurons
293
+ >>> indices = np.array([0, 10, 20, 30])
294
+ >>> specific_input = bs.nn.PoissonInput(
295
+ ... target=neuron.V,
296
+ ... indices=indices,
297
+ ... num_input=50,
298
+ ... freq=100 * u.Hz,
299
+ ... weight=0.2 * u.mV
300
+ ... )
301
+ >>>
302
+ >>> # Run simulation with the inputs
303
+ >>> for t in range(100):
304
+ ... poisson_in.update()
305
+ ... specific_input.update()
306
+ ... neuron.update()
307
+
308
+ Notes
309
+ -----
310
+ - The Poisson inputs are statistically independent between update steps and across
311
+ target neurons.
312
+ - This implementation is particularly efficient for large numbers of inputs or targets.
313
+ - For very sparse connectivity patterns, consider using individual PoissonSpike neurons
314
+ with specific connectivity patterns instead.
315
+ - The update method internally calls the poisson_input function which handles the
316
+ spike generation and target state updates.
180
317
  """
181
318
 
182
319
  def __init__(
@@ -184,8 +321,8 @@ class PoissonInput(Module):
184
321
  target: Prefetch,
185
322
  indices: Union[np.ndarray, jax.Array],
186
323
  num_input: int,
187
- freq: Union[int, float],
188
- weight: Union[int, float],
324
+ freq: u.Quantity[u.Hz],
325
+ weight: Union[jax.typing.ArrayLike, u.Quantity],
189
326
  name: Optional[str] = None,
190
327
  ):
191
328
  super().__init__(name=name)
@@ -212,19 +349,47 @@ class PoissonInput(Module):
212
349
  def poisson_input(
213
350
  freq: u.Quantity[u.Hz],
214
351
  num_input: int,
215
- weight: u.Quantity,
352
+ weight: Union[jax.typing.ArrayLike, u.Quantity],
216
353
  target: State,
217
354
  indices: Optional[Union[np.ndarray, jax.Array]] = None,
218
355
  refractory: Optional[Union[jax.Array]] = None,
219
356
  ):
220
- """
221
- Generates Poisson-distributed input spikes to a target state variable.
357
+ r"""Generates Poisson-distributed input spikes to a target state variable.
222
358
 
223
359
  This function simulates Poisson input to a given state, updating the target
224
360
  variable with generated spikes based on the specified frequency, number of inputs,
225
361
  and synaptic weight. The input can be applied to specific indices of the target
226
362
  or to the entire target if indices are not provided.
227
363
 
364
+ The function uses two different methods to generate the Poisson-distributed input:
365
+ 1. For large numbers of inputs (a > 5 and b > 5), a normal approximation is used
366
+ 2. For smaller numbers, a direct binomial sampling approach is used
367
+
368
+ Mathematical model for Poisson input:
369
+ $$
370
+ P(\text{spike}) = \text{freq} \cdot \text{dt}
371
+ $$
372
+
373
+ For the normal approximation (when a > 5 and b > 5):
374
+ $$
375
+ \text{inputs} \sim \mathcal{N}(a, b \cdot p)
376
+ $$
377
+ where:
378
+ $$
379
+ a = \text{num\_input} \cdot p
380
+ $$
381
+ $$
382
+ b = \text{num\_input} \cdot (1 - p)
383
+ $$
384
+ $$
385
+ p = \text{freq} \cdot \text{dt}
386
+ $$
387
+
388
+ For direct binomial sampling (when a ≤ 5 or b ≤ 5):
389
+ $$
390
+ \text{inputs} \sim \text{Binomial}(\text{num\_input}, p)
391
+ $$
392
+
228
393
  Parameters
229
394
  ----------
230
395
  freq : u.Quantity[u.Hz]
@@ -242,10 +407,52 @@ def poisson_input(
242
407
  A boolean array indicating which parts of the target are in a refractory state
243
408
  and should not be updated. Should be the same length as the target.
244
409
 
245
- Returns
246
- -------
247
- None
248
- The function updates the target state in place with the generated Poisson input.
410
+ Examples
411
+ --------
412
+ >>> import brainstate as bs
413
+ >>> import brainunit as u
414
+ >>> import numpy as np
415
+ >>>
416
+ >>> # Create a membrane potential state
417
+ >>> V = bs.HiddenState(np.zeros(100) * u.mV)
418
+ >>>
419
+ >>> # Add Poisson input to all neurons at 50 Hz
420
+ >>> bs.nn.poisson_input(
421
+ ... freq=50 * u.Hz,
422
+ ... num_input=200,
423
+ ... weight=0.1 * u.mV,
424
+ ... target=V
425
+ ... )
426
+ >>>
427
+ >>> # Apply Poisson input only to a subset of neurons
428
+ >>> indices = np.array([0, 10, 20, 30])
429
+ >>> bs.nn.poisson_input(
430
+ ... freq=100 * u.Hz,
431
+ ... num_input=50,
432
+ ... weight=0.2 * u.mV,
433
+ ... target=V,
434
+ ... indices=indices
435
+ ... )
436
+ >>>
437
+ >>> # Apply input with refractory mask
438
+ >>> refractory = np.zeros(100, dtype=bool)
439
+ >>> refractory[40:60] = True # neurons 40-59 are in refractory period
440
+ >>> bs.nn.poisson_input(
441
+ ... freq=75 * u.Hz,
442
+ ... num_input=100,
443
+ ... weight=0.15 * u.mV,
444
+ ... target=V,
445
+ ... refractory=refractory
446
+ ... )
447
+
448
+ Notes
449
+ -----
450
+ - The function automatically switches between normal approximation and binomial
451
+ sampling based on the input parameters to optimize computation efficiency.
452
+ - For large numbers of inputs, the normal approximation provides significant
453
+ performance improvements.
454
+ - The weight parameter is applied uniformly to all generated spikes.
455
+ - When refractory is provided, the corresponding target elements are not updated.
249
456
  """
250
457
  freq = maybe_state(freq)
251
458
  weight = maybe_state(weight)