brainstate 0.0.2.post20241009__py2.py3-none-any.whl → 0.1.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (175)
  1. brainstate/__init__.py +31 -11
  2. brainstate/_state.py +760 -316
  3. brainstate/_state_test.py +41 -12
  4. brainstate/_utils.py +31 -4
  5. brainstate/augment/__init__.py +40 -0
  6. brainstate/augment/_autograd.py +608 -0
  7. brainstate/augment/_autograd_test.py +1193 -0
  8. brainstate/augment/_eval_shape.py +102 -0
  9. brainstate/augment/_eval_shape_test.py +40 -0
  10. brainstate/augment/_mapping.py +525 -0
  11. brainstate/augment/_mapping_test.py +210 -0
  12. brainstate/augment/_random.py +99 -0
  13. brainstate/{transform → compile}/__init__.py +25 -13
  14. brainstate/compile/_ad_checkpoint.py +204 -0
  15. brainstate/compile/_ad_checkpoint_test.py +51 -0
  16. brainstate/compile/_conditions.py +259 -0
  17. brainstate/compile/_conditions_test.py +221 -0
  18. brainstate/compile/_error_if.py +94 -0
  19. brainstate/compile/_error_if_test.py +54 -0
  20. brainstate/compile/_jit.py +314 -0
  21. brainstate/compile/_jit_test.py +143 -0
  22. brainstate/compile/_loop_collect_return.py +516 -0
  23. brainstate/compile/_loop_collect_return_test.py +59 -0
  24. brainstate/compile/_loop_no_collection.py +185 -0
  25. brainstate/compile/_loop_no_collection_test.py +51 -0
  26. brainstate/compile/_make_jaxpr.py +756 -0
  27. brainstate/compile/_make_jaxpr_test.py +134 -0
  28. brainstate/compile/_progress_bar.py +111 -0
  29. brainstate/compile/_unvmap.py +159 -0
  30. brainstate/compile/_util.py +147 -0
  31. brainstate/environ.py +408 -381
  32. brainstate/environ_test.py +34 -32
  33. brainstate/{nn/event → event}/__init__.py +6 -6
  34. brainstate/event/_csr.py +308 -0
  35. brainstate/event/_csr_test.py +118 -0
  36. brainstate/event/_fixed_probability.py +271 -0
  37. brainstate/event/_fixed_probability_test.py +128 -0
  38. brainstate/event/_linear.py +219 -0
  39. brainstate/event/_linear_test.py +112 -0
  40. brainstate/{nn/event → event}/_misc.py +7 -7
  41. brainstate/functional/_activations.py +521 -511
  42. brainstate/functional/_activations_test.py +300 -300
  43. brainstate/functional/_normalization.py +43 -43
  44. brainstate/functional/_others.py +15 -15
  45. brainstate/functional/_spikes.py +49 -49
  46. brainstate/graph/__init__.py +33 -0
  47. brainstate/graph/_graph_context.py +443 -0
  48. brainstate/graph/_graph_context_test.py +65 -0
  49. brainstate/graph/_graph_convert.py +246 -0
  50. brainstate/graph/_graph_node.py +300 -0
  51. brainstate/graph/_graph_node_test.py +75 -0
  52. brainstate/graph/_graph_operation.py +1746 -0
  53. brainstate/graph/_graph_operation_test.py +724 -0
  54. brainstate/init/_base.py +28 -10
  55. brainstate/init/_generic.py +175 -172
  56. brainstate/init/_random_inits.py +470 -415
  57. brainstate/init/_random_inits_test.py +150 -0
  58. brainstate/init/_regular_inits.py +66 -69
  59. brainstate/init/_regular_inits_test.py +51 -0
  60. brainstate/mixin.py +236 -244
  61. brainstate/mixin_test.py +44 -46
  62. brainstate/nn/__init__.py +26 -51
  63. brainstate/nn/_collective_ops.py +199 -0
  64. brainstate/nn/_dyn_impl/__init__.py +46 -0
  65. brainstate/nn/_dyn_impl/_dynamics_neuron.py +290 -0
  66. brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +162 -0
  67. brainstate/nn/_dyn_impl/_dynamics_synapse.py +320 -0
  68. brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +132 -0
  69. brainstate/nn/_dyn_impl/_inputs.py +154 -0
  70. brainstate/nn/{_projection/__init__.py → _dyn_impl/_projection_alignpost.py} +6 -13
  71. brainstate/nn/_dyn_impl/_rate_rnns.py +400 -0
  72. brainstate/nn/_dyn_impl/_rate_rnns_test.py +64 -0
  73. brainstate/nn/_dyn_impl/_readout.py +128 -0
  74. brainstate/nn/_dyn_impl/_readout_test.py +54 -0
  75. brainstate/nn/_dynamics/__init__.py +37 -0
  76. brainstate/nn/_dynamics/_dynamics_base.py +631 -0
  77. brainstate/nn/_dynamics/_dynamics_base_test.py +79 -0
  78. brainstate/nn/_dynamics/_projection_base.py +346 -0
  79. brainstate/nn/_dynamics/_state_delay.py +453 -0
  80. brainstate/nn/_dynamics/_synouts.py +161 -0
  81. brainstate/nn/_dynamics/_synouts_test.py +58 -0
  82. brainstate/nn/_elementwise/__init__.py +22 -0
  83. brainstate/nn/_elementwise/_dropout.py +418 -0
  84. brainstate/nn/_elementwise/_dropout_test.py +100 -0
  85. brainstate/nn/_elementwise/_elementwise.py +1122 -0
  86. brainstate/nn/_elementwise/_elementwise_test.py +171 -0
  87. brainstate/nn/_exp_euler.py +97 -0
  88. brainstate/nn/_exp_euler_test.py +36 -0
  89. brainstate/nn/_interaction/__init__.py +32 -0
  90. brainstate/nn/_interaction/_connections.py +726 -0
  91. brainstate/nn/_interaction/_connections_test.py +254 -0
  92. brainstate/nn/_interaction/_embedding.py +59 -0
  93. brainstate/nn/_interaction/_normalizations.py +388 -0
  94. brainstate/nn/_interaction/_normalizations_test.py +75 -0
  95. brainstate/nn/_interaction/_poolings.py +1179 -0
  96. brainstate/nn/_interaction/_poolings_test.py +219 -0
  97. brainstate/nn/_module.py +328 -0
  98. brainstate/nn/_module_test.py +211 -0
  99. brainstate/nn/metrics.py +309 -309
  100. brainstate/optim/__init__.py +14 -2
  101. brainstate/optim/_base.py +66 -0
  102. brainstate/optim/_lr_scheduler.py +363 -400
  103. brainstate/optim/_lr_scheduler_test.py +25 -24
  104. brainstate/optim/_optax_optimizer.py +103 -176
  105. brainstate/optim/_optax_optimizer_test.py +41 -1
  106. brainstate/optim/_sgd_optimizer.py +950 -1025
  107. brainstate/random/_rand_funs.py +3269 -3268
  108. brainstate/random/_rand_funs_test.py +568 -0
  109. brainstate/random/_rand_seed.py +149 -117
  110. brainstate/random/_rand_seed_test.py +50 -0
  111. brainstate/random/_rand_state.py +1360 -1318
  112. brainstate/random/_random_for_unit.py +13 -13
  113. brainstate/surrogate.py +1262 -1243
  114. brainstate/{nn/_projection/_utils.py → transform.py} +1 -2
  115. brainstate/typing.py +157 -130
  116. brainstate/util/__init__.py +52 -0
  117. brainstate/util/_caller.py +100 -0
  118. brainstate/util/_dict.py +734 -0
  119. brainstate/util/_dict_test.py +160 -0
  120. brainstate/util/_error.py +28 -0
  121. brainstate/util/_filter.py +178 -0
  122. brainstate/util/_others.py +497 -0
  123. brainstate/util/_pretty_repr.py +208 -0
  124. brainstate/util/_scaling.py +260 -0
  125. brainstate/util/_struct.py +524 -0
  126. brainstate/util/_tracers.py +75 -0
  127. brainstate/{_visualization.py → util/_visualization.py} +16 -16
  128. {brainstate-0.0.2.post20241009.dist-info → brainstate-0.1.0.dist-info}/METADATA +11 -11
  129. brainstate-0.1.0.dist-info/RECORD +135 -0
  130. brainstate/_module.py +0 -1637
  131. brainstate/_module_test.py +0 -207
  132. brainstate/nn/_base.py +0 -251
  133. brainstate/nn/_connections.py +0 -686
  134. brainstate/nn/_dynamics.py +0 -426
  135. brainstate/nn/_elementwise.py +0 -1438
  136. brainstate/nn/_embedding.py +0 -66
  137. brainstate/nn/_misc.py +0 -133
  138. brainstate/nn/_normalizations.py +0 -389
  139. brainstate/nn/_others.py +0 -101
  140. brainstate/nn/_poolings.py +0 -1229
  141. brainstate/nn/_poolings_test.py +0 -231
  142. brainstate/nn/_projection/_align_post.py +0 -546
  143. brainstate/nn/_projection/_align_pre.py +0 -599
  144. brainstate/nn/_projection/_delta.py +0 -241
  145. brainstate/nn/_projection/_vanilla.py +0 -101
  146. brainstate/nn/_rate_rnns.py +0 -410
  147. brainstate/nn/_readout.py +0 -136
  148. brainstate/nn/_synouts.py +0 -166
  149. brainstate/nn/event/csr.py +0 -312
  150. brainstate/nn/event/csr_test.py +0 -118
  151. brainstate/nn/event/fixed_probability.py +0 -276
  152. brainstate/nn/event/fixed_probability_test.py +0 -127
  153. brainstate/nn/event/linear.py +0 -220
  154. brainstate/nn/event/linear_test.py +0 -111
  155. brainstate/random/random_test.py +0 -593
  156. brainstate/transform/_autograd.py +0 -585
  157. brainstate/transform/_autograd_test.py +0 -1181
  158. brainstate/transform/_conditions.py +0 -334
  159. brainstate/transform/_conditions_test.py +0 -220
  160. brainstate/transform/_error_if.py +0 -94
  161. brainstate/transform/_error_if_test.py +0 -55
  162. brainstate/transform/_jit.py +0 -265
  163. brainstate/transform/_jit_test.py +0 -118
  164. brainstate/transform/_loop_collect_return.py +0 -502
  165. brainstate/transform/_loop_no_collection.py +0 -170
  166. brainstate/transform/_make_jaxpr.py +0 -739
  167. brainstate/transform/_make_jaxpr_test.py +0 -131
  168. brainstate/transform/_mapping.py +0 -109
  169. brainstate/transform/_progress_bar.py +0 -111
  170. brainstate/transform/_unvmap.py +0 -143
  171. brainstate/util.py +0 -746
  172. brainstate-0.0.2.post20241009.dist-info/RECORD +0 -87
  173. {brainstate-0.0.2.post20241009.dist-info → brainstate-0.1.0.dist-info}/LICENSE +0 -0
  174. {brainstate-0.0.2.post20241009.dist-info → brainstate-0.1.0.dist-info}/WHEEL +0 -0
  175. {brainstate-0.0.2.post20241009.dist-info → brainstate-0.1.0.dist-info}/top_level.txt +0 -0
@@ -1,599 +0,0 @@
1
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
-
17
- from typing import Optional, Union
18
-
19
- from brainstate._module import (Module, DelayAccess, Projection,
20
- ExtendedUpdateWithBA, ReceiveInputProj,
21
- register_delay_of_target)
22
- from brainstate._utils import set_module_as
23
- from brainstate.mixin import (DelayedInitializer, BindCondData, UpdateReturn, Mode, JointTypes)
24
- from ._utils import is_instance
25
-
26
# Public projection classes exported by this module.
__all__ = [
    'FullProjAlignPreSDMg',
    'FullProjAlignPreDSMg',
    'FullProjAlignPreSD',
    'FullProjAlignPreDS',
]
30
-
31
-
32
def align_pre_add_bef_update(
    syn_desc: DelayedInitializer,
    delay_at,
    delay_cls: ExtendedUpdateWithBA,
    proj_name: Optional[str] = None
):
    """Get, or lazily create, a synapse fed by the delayed output of ``delay_cls``.

    On first use for a given key, the function builds a :class:`DelayAccess` on
    ``delay_cls`` and instantiates the synapse from ``syn_desc``. It wraps both in
    an ``_AlignPreMg`` module and registers that module as a *before*-update of
    ``delay_cls``. The key is built from ``delay_at`` and ``syn_desc.identifier``.
    Later calls with the same key reuse the registered synapse instead of
    creating a new one.

    Args:
        syn_desc: Delayed initializer of the synapse. It is called with no
            arguments to build the synapse instance.
        delay_at: Delay value handed to :class:`DelayAccess`; it is also part of
            the registration key. Callers in this module pass the projection's
            ``delay``.
        delay_cls: The delay module that hosts the before-update.
        proj_name: Forwarded to :class:`DelayAccess` as ``delay_entry``.

    Returns:
        The synapse instance stored in the registered ``_AlignPreMg``.
    """
    _syn_id = f'Delay({str(delay_at)}) // {syn_desc.identifier}'
    if not delay_cls.has_before_update(_syn_id):
        # delay: accessor that reads the delayed value of ``delay_cls``
        delay_access = DelayAccess(delay_cls, delay_at, delay_entry=proj_name)
        # synapse
        syn_cls = syn_desc()
        # add to the "before_updates" of the delay module
        delay_cls.add_before_update(_syn_id, _AlignPreMg(delay_access, syn_cls))
    syn = delay_cls.get_before_update(_syn_id).syn
    return syn
48
-
49
-
50
class _AlignPreMg(Module):
    """Pass the value produced by a delay accessor into a synapse.

    ``align_pre_add_bef_update`` registers instances of this module as a
    before-update of a delay module. Every call reads ``access()`` and returns
    ``syn`` applied to that value.
    """

    def __init__(self, access, syn):
        super().__init__()
        # Zero-argument callable that yields the delayed value
        # (a ``DelayAccess`` where this module is built in this file).
        self.access = access
        # Synapse module that consumes the delayed value.
        self.syn = syn

    def update(self, *args, **kwargs):
        # Extra positional/keyword arguments are accepted but not used.
        delayed_value = self.access()
        return self.syn(delayed_value)
58
-
59
-
60
@set_module_as('brainstate.nn')
class FullProjAlignPreSDMg(Projection):
    """Full-chain synaptic projection with the align-pre reduction and synapse+delay updating and merging.

    The ``full-chain`` means that the model needs to provide all information needed for a projection,
    including ``pre`` -> ``syn`` -> ``delay`` -> ``comm`` -> ``out`` -> ``post``.

    The ``align-pre`` means that the synaptic variables have the same dimension as the pre-synaptic neuron group.

    The ``synapse+delay updating`` means that the projection first computes the synapse states, then delivers the
    synapse states to the delay model, and finally computes the synaptic current.

    The ``merging`` means that the same delay model is shared by all synapses, and the synapse model with the same
    parameters (such as time constants) will also share the same synaptic variables. Concretely, the synapse is
    registered as an after-update of ``pre`` under ``syn.identifier``. Projections whose ``syn`` has the same
    identifier reuse that single synapse instance.

    Neither ``FullProjAlignPreSDMg`` nor ``FullProjAlignPreDSMg`` facilitates event-driven computation,
    because ``comm`` is applied to the synapse state, which is a floating-point value rather than a spike.
    To facilitate event-driven computation, please use align-post projections.

    To simulate an E/I balanced network model:

    .. code-block:: python

        class EINet(bp.DynSysGroup):
          def __init__(self):
            super().__init__()
            ne, ni = 3200, 800
            self.E = bp.dyn.LifRef(ne, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.,
                                   V_initializer=bp.init.Normal(-55., 2.))
            self.I = bp.dyn.LifRef(ni, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.,
                                   V_initializer=bp.init.Normal(-55., 2.))
            self.E2E = bp.dyn.FullProjAlignPreSDMg(pre=self.E,
                                                   syn=bp.dyn.Expon.desc(size=ne, tau=5.),
                                                   delay=0.1,
                                                   comm=bp.dnn.JitFPHomoLinear(ne, ne, prob=0.02, weight=0.6),
                                                   out=bp.dyn.COBA(E=0.),
                                                   post=self.E)
            self.E2I = bp.dyn.FullProjAlignPreSDMg(pre=self.E,
                                                   syn=bp.dyn.Expon.desc(size=ne, tau=5.),
                                                   delay=0.1,
                                                   comm=bp.dnn.JitFPHomoLinear(ne, ni, prob=0.02, weight=0.6),
                                                   out=bp.dyn.COBA(E=0.),
                                                   post=self.I)
            self.I2E = bp.dyn.FullProjAlignPreSDMg(pre=self.I,
                                                   syn=bp.dyn.Expon.desc(size=ni, tau=10.),
                                                   delay=0.1,
                                                   comm=bp.dnn.JitFPHomoLinear(ni, ne, prob=0.02, weight=6.7),
                                                   out=bp.dyn.COBA(E=-80.),
                                                   post=self.E)
            self.I2I = bp.dyn.FullProjAlignPreSDMg(pre=self.I,
                                                   syn=bp.dyn.Expon.desc(size=ni, tau=10.),
                                                   delay=0.1,
                                                   comm=bp.dnn.JitFPHomoLinear(ni, ni, prob=0.02, weight=6.7),
                                                   out=bp.dyn.COBA(E=-80.),
                                                   post=self.I)

          def update(self, inp):
            self.E2E()
            self.E2I()
            self.I2E()
            self.I2I()
            self.E(inp)
            self.I(inp)
            return self.E.spike

        model = EINet()
        indices = bm.arange(1000)
        spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices)
        bp.visualize.raster_plot(indices, spks, show=True)

    Args:
        pre: The pre-synaptic neuron group. The synapse is attached to it as an after-update.
        syn: The delayed initializer of the synaptic dynamics. It is instantiated only if ``pre`` has
            no after-update registered under ``syn.identifier`` yet.
        delay: The synaptic delay. The delay is used only when it is not ``None`` and greater than ``0``.
        comm: The synaptic communication, applied to the (delayed) synapse state.
        out: The synaptic output. It receives the communicated current through ``bind_cond``.
        post: The post-synaptic neuron group. ``out`` is registered on it via ``add_input_fun``.
        out_label: Optional label forwarded to ``post.add_input_fun``.
        name: str. The projection name. It is also used as this projection's delay entry name.
        mode: Mode. The computing mode.
    """

    # Attribute names declared invisible to the module's node handling
    # (semantics defined by the base module class).
    _invisible_nodes = ['pre', 'syn', 'delay', 'post']

    def __init__(
        self,
        pre: ExtendedUpdateWithBA,
        syn: DelayedInitializer[UpdateReturn],
        delay: Union[None, int, float],
        comm: Module,
        out: BindCondData,
        post: ReceiveInputProj,
        out_label: Optional[str] = None,
        name: Optional[str] = None,
        mode: Optional[Mode] = None,
    ):
        super().__init__(name=name, mode=mode)

        # synaptic models
        is_instance(pre, ExtendedUpdateWithBA)
        is_instance(syn, DelayedInitializer[UpdateReturn])
        is_instance(comm, Module)
        is_instance(out, BindCondData)
        is_instance(post, ReceiveInputProj)
        self.comm = comm
        self.out = out

        # synapse initialization: reuse the synapse already attached to ``pre``
        # under the same identifier, so that identical synapse descriptions share
        # one instance (the "merging").
        if not pre.has_after_update(syn.identifier):
            syn_cls = syn()
            pre.add_after_update(syn.identifier, syn_cls)
        self.syn = pre.get_after_update(syn.identifier)

        # delay initialization
        if delay is not None and delay > 0.:
            delay_cls = register_delay_of_target(self.syn)
            # Register this projection's delay length under its own name so that
            # ``self.delay.at(self.name)`` in ``update`` reads an existing entry,
            # as ``FullProjAlignPreSD`` and ``FullProjAlignPreDS`` also do.
            delay_cls.register_entry(self.name, delay)
            self.has_delay = True
            self.delay = delay_cls
        else:
            self.has_delay = False
            self.delay = None

        # output initialization
        post.add_input_fun(self.name, out, label=out_label)

        # references
        self.pre = pre
        self.post = post

    def update(self, x=None):
        """Compute the synaptic current for one step and bind it to ``out``.

        Args:
            x: Optional input for ``comm``. When ``None``, the delayed synapse state
                (``self.delay.at(self.name)``) is used if a delay is configured.
                Otherwise the current synapse state (``self.syn.update_return()``) is used.

        Returns:
            The output of ``comm``, which is also passed to ``out.bind_cond``.
        """
        if x is None:
            if self.has_delay:
                x = self.delay.at(self.name)
            else:
                x = self.syn.update_return()
        current = self.comm(x)
        self.out.bind_cond(current)
        return current
198
-
199
-
200
@set_module_as('brainstate.nn')
class FullProjAlignPreDSMg(Projection):
    """Full-chain synaptic projection with the align-pre reduction and delay+synapse updating and merging.

    Being ``full-chain``, the projection is given every stage it needs:
    ``pre`` -> ``delay`` -> ``syn`` -> ``comm`` -> ``out`` -> ``post``.
    Compared with ``FullProjAlignPreSDMg``, the order of ``delay`` and ``syn`` is swapped.

    ``align-pre`` means that the synaptic variables have the same dimension as the pre-synaptic neuron group.

    ``delay+synapse updating`` means that the output of the pre-synaptic group (usually its spikes) is first
    delivered to the delay model, the synapse states are then computed from the delayed value, and finally
    the synaptic current is computed.

    ``merging`` means that all synapses share one delay model, and synapse models with the same parameters
    (such as time constants) also share the same synaptic variables.

    Neither ``FullProjAlignPreDSMg`` nor ``FullProjAlignPreSDMg`` supports event-driven computation, because
    ``comm`` is applied to the floating-point synapse state rather than to spikes. For event-driven
    computation, use align-post projections.

    To simulate an E/I balanced network model:

    .. code-block:: python

        class EINet(bp.DynSysGroup):
          def __init__(self):
            super().__init__()
            ne, ni = 3200, 800
            self.E = bp.dyn.LifRef(ne, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.,
                                   V_initializer=bp.init.Normal(-55., 2.))
            self.I = bp.dyn.LifRef(ni, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.,
                                   V_initializer=bp.init.Normal(-55., 2.))
            self.E2E = bp.dyn.FullProjAlignPreDSMg(pre=self.E,
                                                   delay=0.1,
                                                   syn=bp.dyn.Expon.desc(size=ne, tau=5.),
                                                   comm=bp.dnn.JitFPHomoLinear(ne, ne, prob=0.02, weight=0.6),
                                                   out=bp.dyn.COBA(E=0.),
                                                   post=self.E)
            self.E2I = bp.dyn.FullProjAlignPreDSMg(pre=self.E,
                                                   delay=0.1,
                                                   syn=bp.dyn.Expon.desc(size=ne, tau=5.),
                                                   comm=bp.dnn.JitFPHomoLinear(ne, ni, prob=0.02, weight=0.6),
                                                   out=bp.dyn.COBA(E=0.),
                                                   post=self.I)
            self.I2E = bp.dyn.FullProjAlignPreDSMg(pre=self.I,
                                                   delay=0.1,
                                                   syn=bp.dyn.Expon.desc(size=ni, tau=10.),
                                                   comm=bp.dnn.JitFPHomoLinear(ni, ne, prob=0.02, weight=6.7),
                                                   out=bp.dyn.COBA(E=-80.),
                                                   post=self.E)
            self.I2I = bp.dyn.FullProjAlignPreDSMg(pre=self.I,
                                                   delay=0.1,
                                                   syn=bp.dyn.Expon.desc(size=ni, tau=10.),
                                                   comm=bp.dnn.JitFPHomoLinear(ni, ni, prob=0.02, weight=6.7),
                                                   out=bp.dyn.COBA(E=-80.),
                                                   post=self.I)

          def update(self, inp):
            self.E2E()
            self.E2I()
            self.I2E()
            self.I2I()
            self.E(inp)
            self.I(inp)
            return self.E.spike

        model = EINet()
        indices = bm.arange(1000)
        spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices)
        bp.visualize.raster_plot(indices, spks, show=True)

    Args:
        pre: The pre-synaptic neuron group; its output is what gets delayed.
        delay: The synaptic delay. A delay module is used only when this is not ``None`` and greater than ``0``.
        syn: The delayed initializer of the synaptic dynamics.
        comm: The synaptic communication, applied to the synapse state.
        out: The synaptic output; the current is handed to it through ``bind_cond``.
        post: The post-synaptic neuron group; ``out`` is registered on it with ``add_input_fun``.
        out_label: Optional label forwarded to ``post.add_input_fun``.
        name: str. The projection name.
        mode: Mode. The computing mode.
    """
    # Attribute names declared invisible to the module's node handling
    # (semantics defined by the base module class).
    _invisible_nodes = ['pre', 'syn', 'delay', 'post']

    def __init__(
        self,
        pre: JointTypes[ExtendedUpdateWithBA, UpdateReturn],
        delay: Union[None, int, float],
        syn: DelayedInitializer[UpdateReturn],
        comm: Module,
        out: BindCondData,
        post: ReceiveInputProj,
        out_label: Optional[str] = None,
        name: Optional[str] = None,
        mode: Optional[Mode] = None,
    ):
        super().__init__(name=name, mode=mode)

        # validate every component before storing anything
        is_instance(pre, JointTypes[ExtendedUpdateWithBA, UpdateReturn])
        is_instance(syn, DelayedInitializer[Module])
        is_instance(comm, Module)
        is_instance(out, BindCondData)
        is_instance(post, ReceiveInputProj)
        self.comm = comm
        self.out = out

        use_delay = delay is not None and delay > 0.
        if not use_delay:
            # No delay: the synapse hangs off ``pre`` as an after-update and is
            # shared between projections whose ``syn`` has the same identifier.
            self.has_delay = False
            self.delay = None
            if not pre.has_after_update(syn.identifier):
                pre.add_after_update(syn.identifier, syn())
            self.syn = pre.get_after_update(syn.identifier)
        else:
            # Delay ``pre``'s output first; the synapse is then attached to the
            # delay module and driven by the delayed value.
            host_delay = register_delay_of_target(pre)
            self.has_delay = True
            self.delay = host_delay
            self.syn = align_pre_add_bef_update(syn, delay, host_delay, self.name)

        # hook the synaptic output into the post-synaptic group
        post.add_input_fun(self.name, out, label=out_label)

        # keep references to both neuron groups
        self.pre = pre
        self.post = post

    def update(self):
        """Read the synapse state, run ``comm`` on it, bind the result to ``out`` and return it."""
        current = self.comm(self.syn.update_return())
        self.out.bind_cond(current)
        return current
335
-
336
-
337
@set_module_as('brainstate.nn')
class FullProjAlignPreSD(Projection):
    """Full-chain synaptic projection with the align-pre reduction and synapse+delay updating.

    Being ``full-chain``, the projection is given every stage it needs:
    ``pre`` -> ``syn`` -> ``delay`` -> ``comm`` -> ``out`` -> ``post``.

    ``align-pre`` means that the synaptic variables have the same dimension as the pre-synaptic neuron group.

    ``synapse+delay updating`` means that the synapse states are computed first, then pushed into the
    delay model, and finally the synaptic current is computed from the (delayed) synapse states.

    Neither ``FullProjAlignPreSD`` nor ``FullProjAlignPreDS`` supports event-driven computation, because
    ``comm`` is applied to the floating-point synapse state rather than to spikes. For event-driven
    computation, use align-post projections.

    To simulate an E/I balanced network model:

    .. code-block:: python

        class EINet(bp.DynSysGroup):
          def __init__(self):
            super().__init__()
            ne, ni = 3200, 800
            self.E = bp.dyn.LifRef(ne, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.,
                                   V_initializer=bp.init.Normal(-55., 2.))
            self.I = bp.dyn.LifRef(ni, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.,
                                   V_initializer=bp.init.Normal(-55., 2.))
            self.E2E = bp.dyn.FullProjAlignPreSD(pre=self.E,
                                                 syn=bp.dyn.Expon.desc(size=ne, tau=5.),
                                                 delay=0.1,
                                                 comm=bp.dnn.JitFPHomoLinear(ne, ne, prob=0.02, weight=0.6),
                                                 out=bp.dyn.COBA(E=0.),
                                                 post=self.E)
            self.E2I = bp.dyn.FullProjAlignPreSD(pre=self.E,
                                                 syn=bp.dyn.Expon.desc(size=ne, tau=5.),
                                                 delay=0.1,
                                                 comm=bp.dnn.JitFPHomoLinear(ne, ni, prob=0.02, weight=0.6),
                                                 out=bp.dyn.COBA(E=0.),
                                                 post=self.I)
            self.I2E = bp.dyn.FullProjAlignPreSD(pre=self.I,
                                                 syn=bp.dyn.Expon.desc(size=ni, tau=10.),
                                                 delay=0.1,
                                                 comm=bp.dnn.JitFPHomoLinear(ni, ne, prob=0.02, weight=6.7),
                                                 out=bp.dyn.COBA(E=-80.),
                                                 post=self.E)
            self.I2I = bp.dyn.FullProjAlignPreSD(pre=self.I,
                                                 syn=bp.dyn.Expon.desc(size=ni, tau=10.),
                                                 delay=0.1,
                                                 comm=bp.dnn.JitFPHomoLinear(ni, ni, prob=0.02, weight=6.7),
                                                 out=bp.dyn.COBA(E=-80.),
                                                 post=self.I)

          def update(self, inp):
            self.E2E()
            self.E2I()
            self.I2E()
            self.I2I()
            self.E(inp)
            self.I(inp)
            return self.E.spike

        model = EINet()
        indices = bm.arange(1000)
        spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices)
        bp.visualize.raster_plot(indices, spks, show=True)

    Args:
        pre: The pre-synaptic neuron group; ``pre.update_return()`` feeds the synapse each step.
        syn: The synaptic dynamics (an already-built module, not a delayed initializer).
        delay: The synaptic delay. A delay module is used only when this is not ``None`` and greater than ``0``.
        comm: The synaptic communication, applied to the (delayed) synapse output.
        out: The synaptic output; the current is handed to it through ``bind_cond``.
        post: The post-synaptic neuron group; ``out`` is registered on it with ``add_input_fun``.
        out_label: Optional label forwarded to ``post.add_input_fun``.
        name: str. The projection name; also the name of this projection's delay entry.
        mode: Mode. The computing mode.
    """

    # Attribute names declared invisible to the module's node handling
    # (semantics defined by the base module class).
    _invisible_nodes = ['pre', 'post']

    def __init__(
        self,
        pre: ExtendedUpdateWithBA,
        syn: UpdateReturn,
        delay: Union[None, int, float],
        comm: Module,
        out: BindCondData,
        post: ReceiveInputProj,
        out_label: Optional[str] = None,
        name: Optional[str] = None,
        mode: Optional[Mode] = None,
    ):
        super().__init__(name=name, mode=mode)

        # validate every component before storing anything
        is_instance(pre, ExtendedUpdateWithBA)
        is_instance(syn, UpdateReturn)
        is_instance(comm, Module)
        is_instance(out, BindCondData)
        is_instance(post, ReceiveInputProj)
        self.comm = comm
        self.syn = syn
        self.out = out

        use_delay = delay is not None and delay > 0.
        if use_delay:
            # delay the synapse output and register this projection's entry on it
            syn_delay = register_delay_of_target(syn)
            syn_delay.register_entry(self.name, delay)
            self.delay = syn_delay
        else:
            self.delay = None

        # hook the synaptic output into the post-synaptic group
        post.add_input_fun(self.name, out, label=out_label)

        # keep references to both neuron groups
        self.pre = pre
        self.post = post

    def update(self):
        """Advance the synapse, optionally delay its output, and bind the resulting current to ``out``.

        Returns:
            The output of ``comm``, which is also passed to ``out.bind_cond``.
        """
        syn_out = self.syn(self.pre.update_return())
        if self.delay is None:
            x = syn_out
        else:
            # push the fresh synapse output into the delay, then read this projection's entry
            self.delay(syn_out)
            x = self.delay.at(self.name)
        current = self.comm(x)
        self.out.bind_cond(current)
        return current
467
-
468
-
469
@set_module_as('brainstate.nn')
class FullProjAlignPreDS(Projection):
    """Full-chain synaptic projection with the align-pre reduction and delay+synapse updating.

    The ``full-chain`` means that the model needs to provide all information needed for a projection,
    including ``pre`` -> ``delay`` -> ``syn`` -> ``comm`` -> ``out`` -> ``post``.
    Note here, compared to ``FullProjAlignPreSD``, the ``delay`` and ``syn`` are exchanged.

    The ``align-pre`` means that the synaptic variables have the same dimension as the pre-synaptic neuron group.

    The ``delay+synapse updating`` means that the projection first delivers the pre neuron output (usually the
    spiking) to the delay model, then computes the synapse states, and finally computes the synaptic current.

    Neither ``FullProjAlignPreDS`` nor ``FullProjAlignPreSD`` facilitates event-driven computation,
    because ``comm`` is applied to the synapse state, which is a floating-point value rather than a spike.
    To facilitate event-driven computation, please use align-post projections.

    To simulate an E/I balanced network model:

    .. code-block:: python

        class EINet(bp.DynSysGroup):
          def __init__(self):
            super().__init__()
            ne, ni = 3200, 800
            self.E = bp.dyn.LifRef(ne, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.,
                                   V_initializer=bp.init.Normal(-55., 2.))
            self.I = bp.dyn.LifRef(ni, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.,
                                   V_initializer=bp.init.Normal(-55., 2.))
            self.E2E = bp.dyn.FullProjAlignPreDS(pre=self.E,
                                                 delay=0.1,
                                                 syn=bp.dyn.Expon.desc(size=ne, tau=5.),
                                                 comm=bp.dnn.JitFPHomoLinear(ne, ne, prob=0.02, weight=0.6),
                                                 out=bp.dyn.COBA(E=0.),
                                                 post=self.E)
            self.E2I = bp.dyn.FullProjAlignPreDS(pre=self.E,
                                                 delay=0.1,
                                                 syn=bp.dyn.Expon.desc(size=ne, tau=5.),
                                                 comm=bp.dnn.JitFPHomoLinear(ne, ni, prob=0.02, weight=0.6),
                                                 out=bp.dyn.COBA(E=0.),
                                                 post=self.I)
            self.I2E = bp.dyn.FullProjAlignPreDS(pre=self.I,
                                                 delay=0.1,
                                                 syn=bp.dyn.Expon.desc(size=ni, tau=10.),
                                                 comm=bp.dnn.JitFPHomoLinear(ni, ne, prob=0.02, weight=6.7),
                                                 out=bp.dyn.COBA(E=-80.),
                                                 post=self.E)
            self.I2I = bp.dyn.FullProjAlignPreDS(pre=self.I,
                                                 delay=0.1,
                                                 syn=bp.dyn.Expon.desc(size=ni, tau=10.),
                                                 comm=bp.dnn.JitFPHomoLinear(ni, ni, prob=0.02, weight=6.7),
                                                 out=bp.dyn.COBA(E=-80.),
                                                 post=self.I)

          def update(self, inp):
            self.E2E()
            self.E2I()
            self.I2E()
            self.I2I()
            self.E(inp)
            self.I(inp)
            return self.E.spike

        model = EINet()
        indices = bm.arange(1000)
        spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices)
        bp.visualize.raster_plot(indices, spks, show=True)

    Args:
        pre: The pre-synaptic neuron group. Its ``update_return()`` output is what gets delayed.
        delay: The synaptic delay. The delay is used only when it is not ``None`` and greater than ``0``.
        syn: The synaptic dynamics (an already-built module), applied to the (delayed) pre output.
        comm: The synaptic communication, applied to the synapse output.
        out: The synaptic output. It receives the communicated current through ``bind_cond``.
        post: The post-synaptic neuron group. ``out`` is registered on it via ``add_input_fun``.
        out_label: Optional label forwarded to ``post.add_input_fun``.
        name: str. The projection name. It is also used as this projection's delay entry name.
        mode: Mode. The computing mode.
    """

    # Attribute names declared invisible to the module's node handling
    # (semantics defined by the base module class).
    _invisible_nodes = ['pre', 'post', 'delay']

    def __init__(
        self,
        pre: UpdateReturn,
        delay: Union[None, int, float],
        syn: Module,
        comm: Module,
        out: BindCondData,
        post: ReceiveInputProj,
        out_label: Optional[str] = None,
        name: Optional[str] = None,
        mode: Optional[Mode] = None,
    ):
        super().__init__(name=name, mode=mode)

        # synaptic models
        is_instance(pre, UpdateReturn)
        is_instance(syn, Module)
        is_instance(comm, Module)
        is_instance(out, BindCondData)
        is_instance(post, ReceiveInputProj)
        self.comm = comm
        self.syn = syn
        self.out = out

        # delay initialization: delay the output of ``pre`` and register this
        # projection's delay length under its own name.
        if delay is not None and delay > 0.:
            delay_cls = register_delay_of_target(pre)
            delay_cls.register_entry(self.name, delay)
            self.delay = delay_cls
        else:
            self.delay = None

        # output initialization
        post.add_input_fun(self.name, out, label=out_label)

        # references
        self.pre = pre
        self.post = post

    def update(self, x=None):
        """Compute the synaptic current for one step and bind it to ``out``.

        Args:
            x: Optional input for ``syn``. When ``None``, the delayed pre output
                (``self.delay.at(self.name)``) is used if a delay is configured.
                Otherwise the current pre output (``self.pre.update_return()``) is used.

        Returns:
            ``comm(syn(x))``, which is also passed to ``out.bind_cond``.
        """
        if x is None:
            if self.delay is not None:
                x = self.delay.at(self.name)
            else:
                x = self.pre.update_return()
        g = self.comm(self.syn(x))
        self.out.bind_cond(g)
        return g
- return g