brainstate 0.0.1__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (79) hide show
  1. brainstate/__init__.py +45 -0
  2. brainstate/_module.py +1466 -0
  3. brainstate/_module_test.py +133 -0
  4. brainstate/_state.py +378 -0
  5. brainstate/_state_test.py +41 -0
  6. brainstate/_utils.py +21 -0
  7. brainstate/environ.py +375 -0
  8. brainstate/functional/__init__.py +25 -0
  9. brainstate/functional/_activations.py +754 -0
  10. brainstate/functional/_normalization.py +69 -0
  11. brainstate/functional/_spikes.py +90 -0
  12. brainstate/init/__init__.py +26 -0
  13. brainstate/init/_base.py +36 -0
  14. brainstate/init/_generic.py +175 -0
  15. brainstate/init/_random_inits.py +489 -0
  16. brainstate/init/_regular_inits.py +109 -0
  17. brainstate/math/__init__.py +21 -0
  18. brainstate/math/_einops.py +787 -0
  19. brainstate/math/_einops_parsing.py +169 -0
  20. brainstate/math/_einops_parsing_test.py +126 -0
  21. brainstate/math/_einops_test.py +346 -0
  22. brainstate/math/_misc.py +298 -0
  23. brainstate/math/_misc_test.py +58 -0
  24. brainstate/mixin.py +373 -0
  25. brainstate/mixin_test.py +73 -0
  26. brainstate/nn/__init__.py +68 -0
  27. brainstate/nn/_base.py +248 -0
  28. brainstate/nn/_connections.py +686 -0
  29. brainstate/nn/_dynamics.py +406 -0
  30. brainstate/nn/_elementwise.py +1437 -0
  31. brainstate/nn/_misc.py +132 -0
  32. brainstate/nn/_normalizations.py +389 -0
  33. brainstate/nn/_others.py +100 -0
  34. brainstate/nn/_poolings.py +1228 -0
  35. brainstate/nn/_poolings_test.py +231 -0
  36. brainstate/nn/_projection/__init__.py +32 -0
  37. brainstate/nn/_projection/_align_post.py +528 -0
  38. brainstate/nn/_projection/_align_pre.py +599 -0
  39. brainstate/nn/_projection/_delta.py +241 -0
  40. brainstate/nn/_projection/_utils.py +17 -0
  41. brainstate/nn/_projection/_vanilla.py +101 -0
  42. brainstate/nn/_rate_rnns.py +393 -0
  43. brainstate/nn/_readout.py +130 -0
  44. brainstate/nn/_synouts.py +166 -0
  45. brainstate/nn/functional/__init__.py +25 -0
  46. brainstate/nn/functional/_activations.py +754 -0
  47. brainstate/nn/functional/_normalization.py +69 -0
  48. brainstate/nn/functional/_spikes.py +90 -0
  49. brainstate/nn/init/__init__.py +26 -0
  50. brainstate/nn/init/_base.py +36 -0
  51. brainstate/nn/init/_generic.py +175 -0
  52. brainstate/nn/init/_random_inits.py +489 -0
  53. brainstate/nn/init/_regular_inits.py +109 -0
  54. brainstate/nn/surrogate.py +1740 -0
  55. brainstate/optim/__init__.py +23 -0
  56. brainstate/optim/_lr_scheduler.py +486 -0
  57. brainstate/optim/_lr_scheduler_test.py +36 -0
  58. brainstate/optim/_sgd_optimizer.py +1148 -0
  59. brainstate/random.py +5148 -0
  60. brainstate/random_test.py +576 -0
  61. brainstate/surrogate.py +1740 -0
  62. brainstate/transform/__init__.py +36 -0
  63. brainstate/transform/_autograd.py +585 -0
  64. brainstate/transform/_autograd_test.py +1183 -0
  65. brainstate/transform/_control.py +665 -0
  66. brainstate/transform/_controls_test.py +220 -0
  67. brainstate/transform/_jit.py +239 -0
  68. brainstate/transform/_jit_error.py +158 -0
  69. brainstate/transform/_jit_test.py +102 -0
  70. brainstate/transform/_make_jaxpr.py +573 -0
  71. brainstate/transform/_make_jaxpr_test.py +133 -0
  72. brainstate/transform/_progress_bar.py +113 -0
  73. brainstate/typing.py +69 -0
  74. brainstate/util.py +747 -0
  75. brainstate-0.0.1.dist-info/LICENSE +202 -0
  76. brainstate-0.0.1.dist-info/METADATA +101 -0
  77. brainstate-0.0.1.dist-info/RECORD +79 -0
  78. brainstate-0.0.1.dist-info/WHEEL +6 -0
  79. brainstate-0.0.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,599 @@
1
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+
17
+ from typing import Optional, Union
18
+
19
+ from brainstate._module import (Module, DelayAccess, Projection,
20
+ ExtendedUpdateWithBA, ReceiveInputProj,
21
+ register_delay_of_target)
22
+ from brainstate._utils import set_module_as
23
+ from brainstate.mixin import (DelayedInitializer, BindCondData, UpdateReturn, Mode, AllOfTypes)
24
+ from ._utils import is_instance
25
+
26
+ __all__ = [
27
+ 'FullProjAlignPreSDMg', 'FullProjAlignPreDSMg',
28
+ 'FullProjAlignPreSD', 'FullProjAlignPreDS',
29
+ ]
30
+
31
+
32
def align_pre_add_bef_update(
    syn_desc: DelayedInitializer,
    delay_at,
    delay_cls: ExtendedUpdateWithBA,
    proj_name: str = None
):
    """Fetch, creating it on first use, the synapse fed by ``delay_cls`` at ``delay_at``.

    The synapse lives among the "before updates" of ``delay_cls`` under a key
    built from the delay position and ``syn_desc.identifier``. When that key is
    already registered, nothing new is instantiated and the synapse stored
    under the key is returned.

    Args:
      syn_desc: Describer of the synapse; calling it instantiates the synapse.
      delay_at: Delay position handed to ``DelayAccess``; also part of the key.
      delay_cls: The delay object that hosts the before-update hook.
      proj_name: Forwarded to ``DelayAccess`` as ``delay_entry``.

    Returns:
      The ``syn`` attribute of the hook registered under the key.
    """
    key = f'Delay({str(delay_at)}) // {syn_desc.identifier}'
    already_registered = delay_cls.has_before_update(key)
    if not already_registered:
        # Build the delay reader first, then the synapse that consumes its value.
        reader = DelayAccess(delay_cls, delay_at, delay_entry=proj_name)
        synapse = syn_desc()
        # Attach the (reader -> synapse) pair to the "before updates" of the delay.
        delay_cls.add_before_update(key, _AlignPreMg(reader, synapse))
    return delay_cls.get_before_update(key).syn
48
+
49
+
50
class _AlignPreMg(Module):
    """Internal hook that pipes the value returned by ``access`` into ``syn``.

    Args:
      access: Zero-argument callable; its return value is the synapse input.
      syn: The synapse, called with the value returned by ``access``.
    """

    def __init__(self, access, syn):
        super().__init__()
        self.access = access
        self.syn = syn

    def update(self, *args, **kwargs):
        # Positional and keyword arguments are accepted but not used.
        delayed = self.access()
        return self.syn(delayed)
58
+
59
+
60
@set_module_as('brainstate.nn')
class FullProjAlignPreSDMg(Projection):
    """Full-chain synaptic projection with the align-pre reduction and synapse+delay updating and merging.

    The ``full-chain`` means that the model needs to provide all information needed for a projection,
    including ``pre`` -> ``syn`` -> ``delay`` -> ``comm`` -> ``out`` -> ``post``.

    The ``align-pre`` means that the synaptic variables have the same dimension as the pre-synaptic neuron group.

    The ``synapse+delay updating`` means that the projection first computes the synapse states, then delivers the
    synapse states to the delay model, and finally computes the synaptic current.

    The ``merging`` means that the same delay model is shared by all synapses, and the synapse model with same
    parameters (such as time constants) will also share the same synaptic variables.

    Neither ``FullProjAlignPreSDMg`` nor ``FullProjAlignPreDSMg`` facilitates the event-driven computation.
    This is because the ``comm`` is computed after the synapse state, which is a floating-point number, rather
    than the spiking. To facilitate the event-driven computation, please use align post projections.

    To simulate an E/I balanced network model:

    .. code-block:: python

        class EINet(bp.DynSysGroup):
            def __init__(self):
                super().__init__()
                ne, ni = 3200, 800
                self.E = bp.dyn.LifRef(ne, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.,
                                       V_initializer=bp.init.Normal(-55., 2.))
                self.I = bp.dyn.LifRef(ni, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.,
                                       V_initializer=bp.init.Normal(-55., 2.))
                self.E2E = bp.dyn.FullProjAlignPreSDMg(pre=self.E,
                                                       syn=bp.dyn.Expon.desc(size=ne, tau=5.),
                                                       delay=0.1,
                                                       comm=bp.dnn.JitFPHomoLinear(ne, ne, prob=0.02, weight=0.6),
                                                       out=bp.dyn.COBA(E=0.),
                                                       post=self.E)
                self.E2I = bp.dyn.FullProjAlignPreSDMg(pre=self.E,
                                                       syn=bp.dyn.Expon.desc(size=ne, tau=5.),
                                                       delay=0.1,
                                                       comm=bp.dnn.JitFPHomoLinear(ne, ni, prob=0.02, weight=0.6),
                                                       out=bp.dyn.COBA(E=0.),
                                                       post=self.I)
                self.I2E = bp.dyn.FullProjAlignPreSDMg(pre=self.I,
                                                       syn=bp.dyn.Expon.desc(size=ni, tau=10.),
                                                       delay=0.1,
                                                       comm=bp.dnn.JitFPHomoLinear(ni, ne, prob=0.02, weight=6.7),
                                                       out=bp.dyn.COBA(E=-80.),
                                                       post=self.E)
                self.I2I = bp.dyn.FullProjAlignPreSDMg(pre=self.I,
                                                       syn=bp.dyn.Expon.desc(size=ni, tau=10.),
                                                       delay=0.1,
                                                       comm=bp.dnn.JitFPHomoLinear(ni, ni, prob=0.02, weight=6.7),
                                                       out=bp.dyn.COBA(E=-80.),
                                                       post=self.I)

            def update(self, inp):
                self.E2E()
                self.E2I()
                self.I2E()
                self.I2I()
                self.E(inp)
                self.I(inp)
                return self.E.spike

        model = EINet()
        indices = bm.arange(1000)
        spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices)
        bp.visualize.raster_plot(indices, spks, show=True)

    NOTE(review): the example relies on ``bp``/``bm`` aliases that this module does not import;
    confirm that it still matches the current API.

    Args:
      pre: The pre-synaptic neuron group.
      syn: The synaptic dynamics.
      delay: The synaptic delay.
      comm: The synaptic communication.
      out: The synaptic output.
      post: The post-synaptic neuron group.
      out_label: Optional label passed as ``label`` to ``post.add_input_fun``.
      name: str. The projection name.
      mode: Mode. The computing mode.
    """

    _invisible_nodes = ['pre', 'syn', 'delay', 'post']

    def __init__(
        self,
        pre: ExtendedUpdateWithBA,
        syn: DelayedInitializer[UpdateReturn],
        delay: Union[None, int, float],
        comm: Module,
        out: BindCondData,
        post: ReceiveInputProj,
        out_label: Optional[str] = None,
        name: Optional[str] = None,
        mode: Optional[Mode] = None,
    ):
        super().__init__(name=name, mode=mode)

        # synaptic models
        is_instance(pre, ExtendedUpdateWithBA)
        is_instance(syn, DelayedInitializer[UpdateReturn])
        is_instance(comm, Module)
        is_instance(out, BindCondData)
        is_instance(post, ReceiveInputProj)
        self.comm = comm
        self.out = out

        # synapse initialization: reuse the synapse already attached to ``pre``
        # under the same identifier, otherwise instantiate and attach it.
        if not pre.has_after_update(syn.identifier):
            syn_cls = syn()
            pre.add_after_update(syn.identifier, syn_cls)
        self.syn = pre.get_after_update(syn.identifier)

        # delay initialization
        if delay is not None and delay > 0.:
            delay_cls = register_delay_of_target(self.syn)
            # ``update`` reads ``self.delay.at(self.name)``, so the entry named after
            # this projection must be registered here (as ``FullProjAlignPreSD`` does).
            delay_cls.register_entry(self.name, delay)
            self.has_delay = True
            self.delay = delay_cls
        else:
            self.has_delay = False
            self.delay = None

        # output initialization
        post.add_input_fun(self.name, out, label=out_label)

        # references
        self.pre = pre
        self.post = post

    def update(self, x=None):
        """Compute the synaptic current and bind it to the output.

        Args:
          x: Optional input for ``comm``. When ``None``, the value is read from the
            delay entry named after this projection if a delay is used, otherwise
            from ``self.syn.update_return()``.

        Returns:
          The value returned by ``self.comm``.
        """
        if x is None:
            if self.has_delay:
                x = self.delay.at(self.name)
            else:
                x = self.syn.update_return()
        current = self.comm(x)
        self.out.bind_cond(current)
        return current
198
+
199
+
200
@set_module_as('brainstate.nn')
class FullProjAlignPreDSMg(Projection):
    """Full-chain, align-pre projection that applies the delay before the synapse, with merging.

    ``full-chain``: every stage of the projection must be supplied, in the order
    ``pre`` -> ``delay`` -> ``syn`` -> ``comm`` -> ``out`` -> ``post``. Compared to
    ``FullProjAlignPreSDMg``, the ``delay`` and ``syn`` stages are exchanged.

    ``align-pre``: the synaptic variables have the same dimension as the
    pre-synaptic neuron group.

    ``delay+synapse updating``: the output of the pre-synaptic group (usually the
    spiking) is first delivered to the delay model, the synapse states are computed
    next, and the synaptic current is computed last.

    ``merging``: the same delay model is shared by all synapses, and synapse models
    with the same parameters (such as time constants) also share the same synaptic
    variables.

    Neither ``FullProjAlignPreDSMg`` nor ``FullProjAlignPreSDMg`` facilitates
    event-driven computation, because ``comm`` is computed after the synapse state,
    which is a floating-point number rather than the spiking. For event-driven
    computation, please use align post projections.

    To simulate an E/I balanced network model:

    .. code-block:: python

        class EINet(bp.DynSysGroup):
            def __init__(self):
                super().__init__()
                ne, ni = 3200, 800
                self.E = bp.dyn.LifRef(ne, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.,
                                       V_initializer=bp.init.Normal(-55., 2.))
                self.I = bp.dyn.LifRef(ni, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.,
                                       V_initializer=bp.init.Normal(-55., 2.))
                self.E2E = bp.dyn.FullProjAlignPreDSMg(
                    pre=self.E, delay=0.1, syn=bp.dyn.Expon.desc(size=ne, tau=5.),
                    comm=bp.dnn.JitFPHomoLinear(ne, ne, prob=0.02, weight=0.6),
                    out=bp.dyn.COBA(E=0.), post=self.E)
                self.E2I = bp.dyn.FullProjAlignPreDSMg(
                    pre=self.E, delay=0.1, syn=bp.dyn.Expon.desc(size=ne, tau=5.),
                    comm=bp.dnn.JitFPHomoLinear(ne, ni, prob=0.02, weight=0.6),
                    out=bp.dyn.COBA(E=0.), post=self.I)
                self.I2E = bp.dyn.FullProjAlignPreDSMg(
                    pre=self.I, delay=0.1, syn=bp.dyn.Expon.desc(size=ni, tau=10.),
                    comm=bp.dnn.JitFPHomoLinear(ni, ne, prob=0.02, weight=6.7),
                    out=bp.dyn.COBA(E=-80.), post=self.E)
                self.I2I = bp.dyn.FullProjAlignPreDSMg(
                    pre=self.I, delay=0.1, syn=bp.dyn.Expon.desc(size=ni, tau=10.),
                    comm=bp.dnn.JitFPHomoLinear(ni, ni, prob=0.02, weight=6.7),
                    out=bp.dyn.COBA(E=-80.), post=self.I)

            def update(self, inp):
                self.E2E()
                self.E2I()
                self.I2E()
                self.I2I()
                self.E(inp)
                self.I(inp)
                return self.E.spike

        model = EINet()
        indices = bm.arange(1000)
        spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices)
        bp.visualize.raster_plot(indices, spks, show=True)

    NOTE(review): the example relies on ``bp``/``bm`` aliases that this module does not import;
    confirm that it still matches the current API.

    Args:
      pre: The pre-synaptic neuron group.
      delay: The synaptic delay.
      syn: The synaptic dynamics.
      comm: The synaptic communication.
      out: The synaptic output.
      post: The post-synaptic neuron group.
      out_label: Optional label passed as ``label`` to ``post.add_input_fun``.
      name: str. The projection name.
      mode: Mode. The computing mode.
    """
    _invisible_nodes = ['pre', 'syn', 'delay', 'post']

    def __init__(
        self,
        pre: AllOfTypes[ExtendedUpdateWithBA, UpdateReturn],
        delay: Union[None, int, float],
        syn: DelayedInitializer[UpdateReturn],
        comm: Module,
        out: BindCondData,
        post: ReceiveInputProj,
        out_label: Optional[str] = None,
        name: Optional[str] = None,
        mode: Optional[Mode] = None,
    ):
        super().__init__(name=name, mode=mode)

        # Validate each component against its expected type.
        # NOTE(review): ``syn`` is annotated ``DelayedInitializer[UpdateReturn]`` but
        # checked against ``DelayedInitializer[Module]`` -- confirm which is intended.
        expected_types = (
            (pre, AllOfTypes[ExtendedUpdateWithBA, UpdateReturn]),
            (syn, DelayedInitializer[Module]),
            (comm, Module),
            (out, BindCondData),
            (post, ReceiveInputProj),
        )
        for component, expected in expected_types:
            is_instance(component, expected)
        self.comm = comm
        self.out = out

        use_delay = delay is not None and delay > 0.
        if use_delay:
            # Delay of the pre-synaptic output; the synapse is attached to the
            # delay as a "before update" and reads the delayed value.
            shared_delay = register_delay_of_target(pre)
            self.has_delay = True
            self.delay = shared_delay
            self.syn = align_pre_add_bef_update(syn, delay, shared_delay, self.name)
        else:
            self.has_delay = False
            self.delay = None
            # Without a delay the synapse is attached to ``pre`` as an "after update",
            # reusing an existing one registered under the same identifier.
            if not pre.has_after_update(syn.identifier):
                pre.add_after_update(syn.identifier, syn())
            self.syn = pre.get_after_update(syn.identifier)

        # Register the output of this projection on the post-synaptic group.
        post.add_input_fun(self.name, out, label=out_label)

        # Keep references to both ends of the projection.
        self.pre = pre
        self.post = post

    def update(self):
        """Feed ``self.syn.update_return()`` through ``comm``, bind the result to ``out`` and return it."""
        syn_value = self.syn.update_return()
        current = self.comm(syn_value)
        self.out.bind_cond(current)
        return current
335
+
336
+
337
@set_module_as('brainstate.nn')
class FullProjAlignPreSD(Projection):
    """Full-chain, align-pre projection that updates the synapse before the delay.

    ``full-chain``: every stage of the projection must be supplied, in the order
    ``pre`` -> ``syn`` -> ``delay`` -> ``comm`` -> ``out`` -> ``post``.

    ``align-pre``: the synaptic variables have the same dimension as the
    pre-synaptic neuron group.

    ``synapse+delay updating``: the synapse states are computed first, then
    delivered to the delay model, and the synaptic current is computed last.

    Neither ``FullProjAlignPreSD`` nor ``FullProjAlignPreDS`` facilitates
    event-driven computation, because ``comm`` is computed after the synapse state,
    which is a floating-point number rather than the spiking. For event-driven
    computation, please use align post projections.

    To simulate an E/I balanced network model:

    .. code-block:: python

        class EINet(bp.DynSysGroup):
            def __init__(self):
                super().__init__()
                ne, ni = 3200, 800
                self.E = bp.dyn.LifRef(ne, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.,
                                       V_initializer=bp.init.Normal(-55., 2.))
                self.I = bp.dyn.LifRef(ni, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.,
                                       V_initializer=bp.init.Normal(-55., 2.))
                self.E2E = bp.dyn.FullProjAlignPreSD(
                    pre=self.E, syn=bp.dyn.Expon.desc(size=ne, tau=5.), delay=0.1,
                    comm=bp.dnn.JitFPHomoLinear(ne, ne, prob=0.02, weight=0.6),
                    out=bp.dyn.COBA(E=0.), post=self.E)
                self.E2I = bp.dyn.FullProjAlignPreSD(
                    pre=self.E, syn=bp.dyn.Expon.desc(size=ne, tau=5.), delay=0.1,
                    comm=bp.dnn.JitFPHomoLinear(ne, ni, prob=0.02, weight=0.6),
                    out=bp.dyn.COBA(E=0.), post=self.I)
                self.I2E = bp.dyn.FullProjAlignPreSD(
                    pre=self.I, syn=bp.dyn.Expon.desc(size=ni, tau=10.), delay=0.1,
                    comm=bp.dnn.JitFPHomoLinear(ni, ne, prob=0.02, weight=6.7),
                    out=bp.dyn.COBA(E=-80.), post=self.E)
                self.I2I = bp.dyn.FullProjAlignPreSD(
                    pre=self.I, syn=bp.dyn.Expon.desc(size=ni, tau=10.), delay=0.1,
                    comm=bp.dnn.JitFPHomoLinear(ni, ni, prob=0.02, weight=6.7),
                    out=bp.dyn.COBA(E=-80.), post=self.I)

            def update(self, inp):
                self.E2E()
                self.E2I()
                self.I2E()
                self.I2I()
                self.E(inp)
                self.I(inp)
                return self.E.spike

        model = EINet()
        indices = bm.arange(1000)
        spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices)
        bp.visualize.raster_plot(indices, spks, show=True)

    NOTE(review): the example relies on ``bp``/``bm`` aliases that this module does not import,
    and passes ``syn`` as a ``.desc(...)`` describer while ``__init__`` checks ``syn`` against
    ``UpdateReturn``; confirm that it still matches the current API.

    Args:
      pre: The pre-synaptic neuron group.
      syn: The synaptic dynamics.
      delay: The synaptic delay.
      comm: The synaptic communication.
      out: The synaptic output.
      post: The post-synaptic neuron group.
      out_label: Optional label passed as ``label`` to ``post.add_input_fun``.
      name: str. The projection name.
      mode: Mode. The computing mode.
    """

    _invisible_nodes = ['pre', 'post']

    def __init__(
        self,
        pre: ExtendedUpdateWithBA,
        syn: UpdateReturn,
        delay: Union[None, int, float],
        comm: Module,
        out: BindCondData,
        post: ReceiveInputProj,
        out_label: Optional[str] = None,
        name: Optional[str] = None,
        mode: Optional[Mode] = None,
    ):
        super().__init__(name=name, mode=mode)

        # Validate each component against its expected type.
        # NOTE(review): ``update`` calls ``pre.update_return()`` although ``pre`` is
        # only checked against ``ExtendedUpdateWithBA`` -- confirm this is sufficient.
        expected_types = (
            (pre, ExtendedUpdateWithBA),
            (syn, UpdateReturn),
            (comm, Module),
            (out, BindCondData),
            (post, ReceiveInputProj),
        )
        for component, expected in expected_types:
            is_instance(component, expected)
        self.comm = comm
        self.syn = syn
        self.out = out

        # A delay is only created for a strictly positive ``delay``; its entry is
        # registered under the name of this projection.
        use_delay = delay is not None and delay > 0.
        if use_delay:
            syn_delay = register_delay_of_target(syn)
            syn_delay.register_entry(self.name, delay)
            self.delay = syn_delay
        else:
            self.delay = None

        # Register the output of this projection on the post-synaptic group.
        post.add_input_fun(self.name, out, label=out_label)

        # Keep references to both ends of the projection.
        self.pre = pre
        self.post = post

    def update(self):
        """Run ``pre`` output through ``syn`` (and the delay, if any), then ``comm``; bind and return the result."""
        delay = self.delay
        syn_out = self.syn(self.pre.update_return())
        if delay is None:
            x = syn_out
        else:
            # Push the fresh synapse output into the delay, then read this projection's entry.
            delay(syn_out)
            x = delay.at(self.name)
        current = self.comm(x)
        self.out.bind_cond(current)
        return current
467
+
468
+
469
@set_module_as('brainstate.nn')
class FullProjAlignPreDS(Projection):
    """Full-chain, align-pre projection that applies the delay before the synapse.

    ``full-chain``: every stage of the projection must be supplied, in the order
    ``pre`` -> ``delay`` -> ``syn`` -> ``comm`` -> ``out`` -> ``post``. Compared to
    ``FullProjAlignPreSD``, the ``delay`` and ``syn`` stages are exchanged.

    ``align-pre``: the synaptic variables have the same dimension as the
    pre-synaptic neuron group.

    ``delay+synapse updating``: the output of the pre-synaptic group (usually the
    spiking) is first delivered to the delay model, the synapse states are computed
    next, and the synaptic current is computed last.

    Neither ``FullProjAlignPreDS`` nor ``FullProjAlignPreSD`` facilitates
    event-driven computation, because ``comm`` is computed after the synapse state,
    which is a floating-point number rather than the spiking. For event-driven
    computation, please use align post projections.

    To simulate an E/I balanced network model:

    .. code-block:: python

        class EINet(bp.DynSysGroup):
            def __init__(self):
                super().__init__()
                ne, ni = 3200, 800
                self.E = bp.dyn.LifRef(ne, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.,
                                       V_initializer=bp.init.Normal(-55., 2.))
                self.I = bp.dyn.LifRef(ni, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.,
                                       V_initializer=bp.init.Normal(-55., 2.))
                self.E2E = bp.dyn.FullProjAlignPreDS(
                    pre=self.E, delay=0.1, syn=bp.dyn.Expon.desc(size=ne, tau=5.),
                    comm=bp.dnn.JitFPHomoLinear(ne, ne, prob=0.02, weight=0.6),
                    out=bp.dyn.COBA(E=0.), post=self.E)
                self.E2I = bp.dyn.FullProjAlignPreDS(
                    pre=self.E, delay=0.1, syn=bp.dyn.Expon.desc(size=ne, tau=5.),
                    comm=bp.dnn.JitFPHomoLinear(ne, ni, prob=0.02, weight=0.6),
                    out=bp.dyn.COBA(E=0.), post=self.I)
                self.I2E = bp.dyn.FullProjAlignPreDS(
                    pre=self.I, delay=0.1, syn=bp.dyn.Expon.desc(size=ni, tau=10.),
                    comm=bp.dnn.JitFPHomoLinear(ni, ne, prob=0.02, weight=6.7),
                    out=bp.dyn.COBA(E=-80.), post=self.E)
                self.I2I = bp.dyn.FullProjAlignPreDS(
                    pre=self.I, delay=0.1, syn=bp.dyn.Expon.desc(size=ni, tau=10.),
                    comm=bp.dnn.JitFPHomoLinear(ni, ni, prob=0.02, weight=6.7),
                    out=bp.dyn.COBA(E=-80.), post=self.I)

            def update(self, inp):
                self.E2E()
                self.E2I()
                self.I2E()
                self.I2I()
                self.E(inp)
                self.I(inp)
                return self.E.spike

        model = EINet()
        indices = bm.arange(1000)
        spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices)
        bp.visualize.raster_plot(indices, spks, show=True)

    NOTE(review): the example relies on ``bp``/``bm`` aliases that this module does not import,
    and passes ``syn`` as a ``.desc(...)`` describer while ``__init__`` checks ``syn`` against
    ``Module``; confirm that it still matches the current API.

    Args:
      pre: The pre-synaptic neuron group.
      delay: The synaptic delay.
      syn: The synaptic dynamics.
      comm: The synaptic communication.
      out: The synaptic output.
      post: The post-synaptic neuron group.
      out_label: Optional label passed as ``label`` to ``post.add_input_fun``.
      name: str. The projection name.
      mode: Mode. The computing mode.
    """

    _invisible_nodes = ['pre', 'post', 'delay']

    def __init__(
        self,
        pre: UpdateReturn,
        delay: Union[None, int, float],
        syn: Module,
        comm: Module,
        out: BindCondData,
        post: ReceiveInputProj,
        out_label: Optional[str] = None,
        name: Optional[str] = None,
        mode: Optional[Mode] = None,
    ):
        super().__init__(name=name, mode=mode)

        # Validate each component against its expected type.
        expected_types = (
            (pre, UpdateReturn),
            (syn, Module),
            (comm, Module),
            (out, BindCondData),
            (post, ReceiveInputProj),
        )
        for component, expected in expected_types:
            is_instance(component, expected)
        self.comm = comm
        self.syn = syn
        self.out = out

        # A delay is only created for a strictly positive ``delay``; its entry is
        # registered under the name of this projection.
        use_delay = delay is not None and delay > 0.
        if use_delay:
            pre_delay = register_delay_of_target(pre)
            pre_delay.register_entry(self.name, delay)
            self.delay = pre_delay
        else:
            self.delay = None

        # Register the output of this projection on the post-synaptic group.
        post.add_input_fun(self.name, out, label=out_label)

        # Keep references to both ends of the projection.
        self.pre = pre
        self.post = post

    def update(self, x=None):
        """Compute the synaptic current and bind it to the output.

        Args:
          x: Optional input for ``syn``. When ``None``, the value is read from the
            delay entry named after this projection if a delay exists, otherwise
            from ``self.pre.update_return()``.

        Returns:
          The value returned by ``self.comm``.
        """
        if x is None:
            x = self.pre.update_return() if self.delay is None else self.delay.at(self.name)
        g = self.comm(self.syn(x))
        self.out.bind_cond(g)
        return g