brainstate 0.1.3__py2.py3-none-any.whl → 0.1.4__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35):
  1. brainstate/__init__.py +1 -1
  2. brainstate/_compatible_import.py +1 -16
  3. brainstate/compile/_jit.py +14 -5
  4. brainstate/compile/_make_jaxpr.py +78 -22
  5. brainstate/compile/_make_jaxpr_test.py +13 -2
  6. brainstate/graph/_graph_node.py +1 -1
  7. brainstate/graph/_graph_operation.py +4 -4
  8. brainstate/mixin.py +31 -2
  9. brainstate/nn/__init__.py +8 -5
  10. brainstate/nn/_delay.py +13 -1
  11. brainstate/nn/_dropout.py +5 -4
  12. brainstate/nn/_dynamics.py +39 -44
  13. brainstate/nn/{_fixedprob_mv.py → _fixedprob.py} +95 -24
  14. brainstate/nn/_linear_mv.py +1 -1
  15. brainstate/nn/_module.py +5 -5
  16. brainstate/nn/_projection.py +190 -98
  17. brainstate/nn/_synapse.py +5 -9
  18. brainstate/nn/_synaptic_projection.py +376 -86
  19. brainstate/surrogate.py +1 -1
  20. brainstate/typing.py +1 -1
  21. brainstate/util/__init__.py +14 -14
  22. brainstate/util/{_pretty_pytree.py → pretty_pytree.py} +2 -2
  23. {brainstate-0.1.3.dist-info → brainstate-0.1.4.dist-info}/METADATA +1 -1
  24. {brainstate-0.1.3.dist-info → brainstate-0.1.4.dist-info}/RECORD +35 -35
  25. /brainstate/nn/{_fixedprob_mv_test.py → _fixedprob_test.py} +0 -0
  26. /brainstate/util/{_caller.py → caller.py} +0 -0
  27. /brainstate/util/{_error.py → error.py} +0 -0
  28. /brainstate/util/{_others.py → others.py} +0 -0
  29. /brainstate/util/{_pretty_repr.py → pretty_repr.py} +0 -0
  30. /brainstate/util/{_pretty_table.py → pretty_table.py} +0 -0
  31. /brainstate/util/{_scaling.py → scaling.py} +0 -0
  32. /brainstate/util/{_struct.py → struct.py} +0 -0
  33. {brainstate-0.1.3.dist-info → brainstate-0.1.4.dist-info}/LICENSE +0 -0
  34. {brainstate-0.1.3.dist-info → brainstate-0.1.4.dist-info}/WHEEL +0 -0
  35. {brainstate-0.1.3.dist-info → brainstate-0.1.4.dist-info}/top_level.txt +0 -0
@@ -13,25 +13,31 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- from typing import Union, Callable, Optional
16
+ from typing import Callable, Union
17
+ from typing import Optional
18
+
19
+ import brainevent
20
+ import brainunit as u
17
21
 
18
22
  from brainstate._state import State
19
- from brainstate.mixin import AlignPost, ParamDescriber, BindCondData, JointTypes
20
- from brainstate.util._others import get_unique_name
23
+ from brainstate.mixin import BindCondData, JointTypes
24
+ from brainstate.mixin import ParamDescriber, AlignPost
25
+ from brainstate.util.others import get_unique_name
21
26
  from ._collective_ops import call_order
22
- from ._dynamics import Dynamics, maybe_init_prefetch, Prefetch, PrefetchDelayAt
27
+ from ._dynamics import Dynamics, Projection, maybe_init_prefetch, Prefetch, PrefetchDelayAt
23
28
  from ._module import Module
29
+ from ._stp import ShortTermPlasticity
30
+ from ._synapse import Synapse
24
31
  from ._synouts import SynOut
25
32
 
26
33
  __all__ = [
27
34
  'AlignPostProj',
28
35
  'DeltaProj',
29
36
  'CurrentProj',
30
- ]
31
-
32
37
 
33
- class Interaction(Module):
34
- __module__ = 'brainstate.nn'
38
+ 'align_pre_projection',
39
+ 'align_post_projection',
40
+ ]
35
41
 
36
42
 
37
43
  def _check_modules(*modules):
@@ -101,7 +107,7 @@ class _AlignPost(Module):
101
107
  self.out.bind_cond(self.syn(*args, **kwargs))
102
108
 
103
109
 
104
- class AlignPostProj(Interaction):
110
+ class AlignPostProj(Projection):
105
111
  """
106
112
  Align-post projection of the neural network.
107
113
 
@@ -113,7 +119,6 @@ class AlignPostProj(Interaction):
113
119
  Note that this projection needs the manual input of pre-synaptic spikes.
114
120
 
115
121
  >>> import brainstate
116
- >>> import brainevent
117
122
  >>> import brainunit as u
118
123
  >>> n_exc = 3200
119
124
  >>> n_inh = 800
@@ -124,16 +129,14 @@ class AlignPostProj(Interaction):
124
129
  ... tau=20. * u.ms, tau_ref=5. * u.ms,
125
130
  ... V_initializer=brainstate.init.Normal(-55., 2., unit=u.mV)
126
131
  ... )
127
- >>> pop.reset_state()
132
+ >>> pop.init_state()
128
133
  >>> E = brainstate.nn.AlignPostProj(
129
- ... comm=brainevent.nn.FixedProb(n_exc, num, prob=80 / num, weight=1.62 * u.mS),
134
+ ... comm=brainstate.nn.FixedNumConn(n_exc, num, prob=80 / num, weight=1.62 * u.mS),
130
135
  ... syn=brainstate.nn.Expon.desc(num, tau=5. * u.ms),
131
136
  ... out=brainstate.nn.CUBA.desc(scale=u.volt),
132
137
  ... post=pop
133
138
  ... )
134
- >>> exe_current = E(pop.spike.value)
135
-
136
-
139
+ >>> exe_current = E(pop.get_spike())
137
140
 
138
141
  """
139
142
  __module__ = 'brainstate.nn'
@@ -226,48 +229,47 @@ class AlignPostProj(Interaction):
226
229
  self.out.bind_cond(conductance)
227
230
 
228
231
 
229
- class DeltaProj(Interaction):
230
- r"""Full-chain of the synaptic projection for the Delta synapse model.
231
-
232
- The synaptic projection requires the input is the spiking data, otherwise
233
- the synapse is not the Delta synapse model.
234
-
235
- The ``full-chain`` means that the model needs to provide all information needed for a projection,
236
- including ``pre`` -> ``delay`` -> ``comm`` -> ``post``.
237
-
238
- **Model Descriptions**
239
-
240
- .. math::
241
-
242
- I_{syn} (t) = \sum_{j\in C} g_{\mathrm{max}} * \delta(t-t_j-D)
243
-
244
- where :math:`g_{\mathrm{max}}` denotes the chemical synaptic strength,
245
- :math:`t_j` the spiking moment of the presynaptic neuron :math:`j`,
246
- :math:`C` the set of neurons connected to the post-synaptic neuron,
247
- and :math:`D` the transmission delay of chemical synapses.
248
- For simplicity, the rise and decay phases of post-synaptic currents are
249
- omitted in this model.
250
-
251
- # brainstate.nn.DeltaInteraction(
252
- # LIF().prefetch('V'), bst.surrogate.ReluGrad(), comm, post
253
- # )
254
-
255
- Args:
256
- pre: The pre-synaptic neuron group.
257
- delay: The synaptic delay.
258
- comm: DynamicalSystem. The synaptic communication.
259
- post: DynamicalSystem. The post-synaptic neuron group.
232
+ class DeltaProj(Projection):
260
233
  """
234
+ Delta-based projection of the neural network.
235
+
236
+ This projection directly applies delta inputs to post-synaptic neurons without intervening
237
+ synaptic dynamics. It processes inputs through optional prefetch modules, applies a communication model,
238
+ and adds the result directly as a delta input to the post-synaptic population.
239
+
240
+ Parameters
241
+ ----------
242
+ *prefetch : State or callable
243
+ Optional prefetch modules to process input before communication.
244
+ comm : callable
245
+ Communication model that determines how signals are transmitted.
246
+ post : Dynamics
247
+ Post-synaptic neural population to receive the delta inputs.
248
+ label : Optional[str], default=None
249
+ Optional label for the projection to identify it in the post-synaptic population.
261
250
 
251
+ Examples
252
+ --------
253
+ >>> import brainstate
254
+ >>> import brainunit as u
255
+ >>> n_neurons = 100
256
+ >>> pop = brainstate.nn.LIF(n_neurons, V_rest=-70*u.mV, V_threshold=-50*u.mV)
257
+ >>> pop.init_state()
258
+ >>> delta_input = brainstate.nn.DeltaProj(
259
+ ... comm=lambda x: x * 10.0*u.mV,
260
+ ... post=pop
261
+ ... )
262
+ >>> delta_input(1.0) # Apply voltage increment directly
263
+ """
262
264
  __module__ = 'brainstate.nn'
263
265
 
264
- def __init__(self, *modules, comm: Callable, post: Dynamics, label=None):
266
+ def __init__(self, *prefetch, comm: Callable, post: Dynamics, label=None):
265
267
  super().__init__(name=get_unique_name(self.__class__.__name__))
266
268
 
267
269
  self.label = label
268
270
 
269
271
  # checking modules
270
- self.modules = _check_modules(*modules)
272
+ self.prefetches = _check_modules(*prefetch)
271
273
 
272
274
  # checking communication model
273
275
  if not callable(comm):
@@ -285,58 +287,69 @@ class DeltaProj(Interaction):
285
287
 
286
288
  @call_order(2)
287
289
  def init_state(self, *args, **kwargs):
288
- for module in self.modules:
289
- maybe_init_prefetch(module, *args, **kwargs)
290
+ for prefetch in self.prefetches:
291
+ maybe_init_prefetch(prefetch, *args, **kwargs)
290
292
 
291
293
  def update(self, *x):
292
- for module in self.modules:
294
+ for module in self.prefetches:
293
295
  x = (call_module(module, *x),)
294
296
  assert len(x) == 1, f'The output of the modules should be a single value, but got {x}.'
295
297
  x = self.comm(x[0])
296
298
  self.post.add_delta_input(self.name, x, label=self.label)
297
299
 
298
300
 
299
- class CurrentProj(Interaction):
301
+ class CurrentProj(Projection):
300
302
  """
301
- Full-chain synaptic projection with the align-pre reduction and delay+synapse updating and merging.
302
-
303
- The ``full-chain`` means that the model needs to provide all information needed for a projection,
304
- including ``pre`` -> ``delay`` -> ``syn`` -> ``comm`` -> ``out`` -> ``post``.
305
- Note here, compared to ``FullProjAlignPreSD``, the ``delay`` and ``syn`` are exchanged.
306
-
307
- The ``align-pre`` means that the synaptic variables have the same dimension as the pre-synaptic neuron group.
308
-
309
- The ``delay+synapse updating`` means that the projection first delivers the pre neuron output (usually the
310
- spiking) to the delay model, then computes the synapse states, and finally computes the synaptic current.
311
-
312
- The ``merging`` means that the same delay model is shared by all synapses, and the synapse model with same
313
- parameters (such like time constants) will also share the same synaptic variables.
303
+ Current-based projection of the neural network.
304
+
305
+ This projection directly modulates post-synaptic currents without separate synaptic dynamics.
306
+ It processes inputs through optional prefetch modules, applies a communication model,
307
+ and binds the result to the output model which is then added as a current input to the post-synaptic population.
308
+
309
+ Parameters
310
+ ----------
311
+ *prefetch : State or callable
312
+ Optional prefetch modules to process input before communication.
313
+ The last element must be an instance of Prefetch or PrefetchDelayAt if any are provided.
314
+ comm : callable
315
+ Communication model that determines how signals are transmitted.
316
+ out : SynOut
317
+ Output model that converts communication results to post-synaptic currents.
318
+ post : Dynamics
319
+ Post-synaptic neural population to receive the currents.
314
320
 
315
- Neither ``FullProjAlignPreDS`` nor ``FullProjAlignPreSD`` facilitates the event-driven computation.
316
- This is because the ``comm`` is computed after the synapse state, which is a floating-point number, rather
317
- than the spiking. To facilitate the event-driven computation, please use align post projections.
318
-
319
- Args:
320
- prefetch: The synaptic dynamics.
321
- comm: The synaptic communication.
322
- out: The synaptic output.
323
- post: The post-synaptic neuron group.
321
+ Examples
322
+ --------
323
+ >>> import brainstate
324
+ >>> import brainunit as u
325
+ >>> n_neurons = 100
326
+ >>> pop = brainstate.nn.LIF(n_neurons, V_rest=-70*u.mV, V_threshold=-50*u.mV)
327
+ >>> pop.init_state()
328
+ >>> current_input = brainstate.nn.CurrentProj(
329
+ ... comm=lambda x: x * 0.5,
330
+ ... out=brainstate.nn.CUBA(scale=1.0*u.nA),
331
+ ... post=pop
332
+ ... )
333
+ >>> current_input(0.2) # Apply external current
324
334
  """
325
335
  __module__ = 'brainstate.nn'
326
336
 
327
337
  def __init__(
328
338
  self,
329
- prefetch: Union[Prefetch, PrefetchDelayAt],
339
+ *prefetch,
330
340
  comm: Callable,
331
341
  out: SynOut,
332
342
  post: Dynamics,
333
343
  ):
334
344
  super().__init__(name=get_unique_name(self.__class__.__name__))
335
345
 
336
- # pre-synaptic neuron group
337
- if not isinstance(prefetch, (Prefetch, PrefetchDelayAt)):
338
- raise TypeError(f'The pre should be a Prefetch or PrefetchDelayAt, but got {prefetch}.')
346
+ # check prefetch
339
347
  self.prefetch = prefetch
348
+ if len(self.prefetch) > 0 and not isinstance(prefetch[-1], (Prefetch, PrefetchDelayAt)):
349
+ raise TypeError(
350
+ f'The last element of prefetch should be an instance of {Prefetch} or {PrefetchDelayAt}, '
351
+ f'but got {prefetch[-1]}.'
352
+ )
340
353
 
341
354
  # check out
342
355
  if not isinstance(out, SynOut):
@@ -354,41 +367,120 @@ class CurrentProj(Interaction):
354
367
 
355
368
  @call_order(2)
356
369
  def init_state(self, *args, **kwargs):
357
- maybe_init_prefetch(self.prefetch, *args, **kwargs)
370
+ for prefetch in self.prefetch:
371
+ maybe_init_prefetch(prefetch, *args, **kwargs)
358
372
 
359
373
  def update(self, *x):
360
- x = self.prefetch(*x)
361
- x = self.comm(x)
374
+ for prefetch in self.prefetch:
375
+ x = (call_module(prefetch, *x),)
376
+ x = self.comm(*x)
362
377
  self.out.bind_cond(x)
363
378
 
364
379
 
365
- class RawProj(Interaction):
380
+ class align_pre_projection(Projection):
366
381
  """
382
+ Represents a pre-synaptic alignment projection mechanism.
383
+
384
+ This class inherits from the `Projection` base class and is designed to
385
+ manage the pre-synaptic alignment process in neural network simulations.
386
+ It takes into account pre-synaptic dynamics, synaptic properties, delays,
387
+ communication functions, synaptic outputs, post-synaptic dynamics, and
388
+ short-term plasticity.
389
+
390
+ Attributes:
391
+ pre (Dynamics): The pre-synaptic dynamics object.
392
+ syn (Synapse): The synaptic object after pre-synaptic alignment.
393
+ delay (u.Quantity[u.second]): The output delay from the synapse.
394
+ projection (CurrentProj): The current projection object handling communication,
395
+ output, and post-synaptic dynamics.
396
+ stp (ShortTermPlasticity, optional): The short-term plasticity object,
397
+ defaults to None.
367
398
  """
368
- __module__ = 'brainstate.nn'
369
399
 
370
400
  def __init__(
371
401
  self,
402
+ *spike_generator,
403
+ syn: Dynamics,
372
404
  comm: Callable,
373
405
  out: SynOut,
374
406
  post: Dynamics,
407
+ stp: ShortTermPlasticity = None,
375
408
  ):
376
- super().__init__(name=get_unique_name(self.__class__.__name__))
409
+ super().__init__()
377
410
 
378
- # check out
379
- if not isinstance(out, SynOut):
380
- raise TypeError(f'The out should be a SynOut, but got {out}.')
381
- self.out = out
411
+ self.spike_generator = _check_modules(*spike_generator)
412
+ self.projection = CurrentProj(comm=comm, out=out, post=post)
413
+ self.syn = syn
414
+ self.stp = stp
382
415
 
383
- # check post
384
- if not isinstance(post, Dynamics):
385
- raise TypeError(f'The post should be a Dynamics, but got {post}.')
386
- self.post = post
387
- post.add_current_input(self.name, out)
416
+ @call_order(2)
417
+ def init_state(self, *args, **kwargs):
418
+ for module in self.spike_generator:
419
+ maybe_init_prefetch(module, *args, **kwargs)
388
420
 
389
- # output initialization
390
- self.comm = comm
421
+ def update(self, *x):
422
+ for fun in self.spike_generator:
423
+ x = fun(*x)
424
+ if isinstance(x, (tuple, list)):
425
+ x = tuple(x)
426
+ else:
427
+ x = (x,)
428
+ assert len(x) == 1, "Spike generator must return a single value or a tuple/list of values"
429
+ x = brainevent.BinaryArray(x[0]) # Ensure input is a BinaryFloat for spike generation
430
+ if self.stp is not None:
431
+ x = brainevent.MaskedFloat(self.stp(x)) # Ensure STP output is a MaskedFloat
432
+ x = self.syn(x) # Apply pre-synaptic alignment
433
+ return self.projection(x)
434
+
435
+
436
+ class align_post_projection(Projection):
437
+ """
438
+ Represents a post-synaptic alignment projection mechanism.
391
439
 
392
- def update(self, x):
393
- x = self.comm(x)
394
- self.out.bind_cond(x)
440
+ This class inherits from the `Projection` base class and is designed to
441
+ manage the post-synaptic alignment process in neural network simulations.
442
+ It takes into account spike generators, communication functions, synaptic
443
+ properties, synaptic outputs, post-synaptic dynamics, and short-term plasticity.
444
+
445
+ Args:
446
+ *spike_generator: Callable(s) that generate spike events or transform input spikes.
447
+ comm (Callable): Communication function for the projection.
448
+ syn (Union[AlignPost, ParamDescriber[AlignPost]]): The post-synaptic alignment object or its parameter describer.
449
+ out (Union[SynOut, ParamDescriber[SynOut]]): The synaptic output object or its parameter describer.
450
+ post (Dynamics): The post-synaptic dynamics object.
451
+ stp (ShortTermPlasticity, optional): The short-term plasticity object, defaults to None.
452
+
453
+ """
454
+
455
+ def __init__(
456
+ self,
457
+ *spike_generator,
458
+ comm: Callable,
459
+ syn: Union[AlignPost, ParamDescriber[AlignPost]],
460
+ out: Union[SynOut, ParamDescriber[SynOut]],
461
+ post: Dynamics,
462
+ stp: ShortTermPlasticity = None,
463
+ ):
464
+ super().__init__()
465
+
466
+ self.spike_generator = _check_modules(*spike_generator)
467
+ self.projection = AlignPostProj(comm=comm, syn=syn, out=out, post=post)
468
+ self.stp = stp
469
+
470
+ @call_order(2)
471
+ def init_state(self, *args, **kwargs):
472
+ for module in self.spike_generator:
473
+ maybe_init_prefetch(module, *args, **kwargs)
474
+
475
+ def update(self, *x):
476
+ for fun in self.spike_generator:
477
+ x = fun(*x)
478
+ if isinstance(x, (tuple, list)):
479
+ x = tuple(x)
480
+ else:
481
+ x = (x,)
482
+ assert len(x) == 1, "Spike generator must return a single value or a tuple/list of values"
483
+ x = brainevent.BinaryArray(x[0]) # Ensure input is a BinaryFloat for spike generation
484
+ if self.stp is not None:
485
+ x = brainevent.MaskedFloat(self.stp(x)) # Ensure STP output is a MaskedFloat
486
+ return self.projection(x)
brainstate/nn/_synapse.py CHANGED
@@ -307,9 +307,6 @@ class Alpha(Synapse):
307
307
  self.h.value = self.sum_delta_inputs(h)
308
308
  if x is not None:
309
309
  self.h.value += x
310
- return self.update_return()
311
-
312
- def update_return(self) -> PyTree:
313
310
  return self.g.value
314
311
 
315
312
 
@@ -394,7 +391,7 @@ class AMPA(Synapse):
394
391
  beta: ArrayLike = 0.18 / u.ms,
395
392
  T: ArrayLike = 0.5 * u.mM,
396
393
  T_dur: ArrayLike = 0.5 * u.ms,
397
- g_initializer: ArrayLike | Callable = init.ZeroInit(),
394
+ g_initializer: ArrayLike | Callable = init.ZeroInit(unit=u.mS),
398
395
  ):
399
396
  super().__init__(name=name, in_size=in_size)
400
397
 
@@ -413,14 +410,12 @@ class AMPA(Synapse):
413
410
  self.g.value = init.param(self.g_initializer, self.varshape, batch_or_mode)
414
411
  self.spike_arrival_time.value = init.param(init.Constant(-1e7 * u.ms), self.varshape, batch_or_mode)
415
412
 
416
- def dg(self, g, t, TT):
417
- return self.alpha * TT * (1 - g) - self.beta * g
418
-
419
413
  def update(self, pre_spike):
420
414
  t = environ.get('t')
421
415
  self.spike_arrival_time.value = u.math.where(pre_spike, t, self.spike_arrival_time.value)
422
416
  TT = ((t - self.spike_arrival_time.value) < self.T_duration) * self.T
423
- self.g.value = exp_euler_step(self.dg, self.g.value, t, TT)
417
+ dg = lambda g: self.alpha * TT * (1 - g) - self.beta * g
418
+ self.g.value = exp_euler_step(dg, self.g.value)
424
419
  return self.update_return()
425
420
 
426
421
  def update_return(self) -> PyTree:
@@ -507,7 +502,7 @@ class GABAa(AMPA):
507
502
  beta: ArrayLike = 0.18 / u.ms,
508
503
  T: ArrayLike = 1.0 * u.mM,
509
504
  T_dur: ArrayLike = 1.0 * u.ms,
510
- g_initializer: ArrayLike | Callable = init.ZeroInit(),
505
+ g_initializer: ArrayLike | Callable = init.ZeroInit(unit=u.mS),
511
506
  ):
512
507
  super().__init__(
513
508
  alpha=alpha,
@@ -518,3 +513,4 @@ class GABAa(AMPA):
518
513
  in_size=in_size,
519
514
  g_initializer=g_initializer
520
515
  )
516
+