brainstate 0.1.8__py2.py3-none-any.whl → 0.1.10__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (133) hide show
  1. brainstate/__init__.py +58 -51
  2. brainstate/_compatible_import.py +148 -148
  3. brainstate/_state.py +1605 -1663
  4. brainstate/_state_test.py +52 -52
  5. brainstate/_utils.py +47 -47
  6. brainstate/augment/__init__.py +30 -30
  7. brainstate/augment/_autograd.py +778 -778
  8. brainstate/augment/_autograd_test.py +1289 -1289
  9. brainstate/augment/_eval_shape.py +99 -99
  10. brainstate/augment/_eval_shape_test.py +38 -38
  11. brainstate/augment/_mapping.py +1060 -1060
  12. brainstate/augment/_mapping_test.py +597 -597
  13. brainstate/augment/_random.py +151 -151
  14. brainstate/compile/__init__.py +38 -38
  15. brainstate/compile/_ad_checkpoint.py +204 -204
  16. brainstate/compile/_ad_checkpoint_test.py +49 -49
  17. brainstate/compile/_conditions.py +256 -256
  18. brainstate/compile/_conditions_test.py +220 -220
  19. brainstate/compile/_error_if.py +92 -92
  20. brainstate/compile/_error_if_test.py +52 -52
  21. brainstate/compile/_jit.py +346 -346
  22. brainstate/compile/_jit_test.py +143 -143
  23. brainstate/compile/_loop_collect_return.py +536 -536
  24. brainstate/compile/_loop_collect_return_test.py +58 -58
  25. brainstate/compile/_loop_no_collection.py +184 -184
  26. brainstate/compile/_loop_no_collection_test.py +50 -50
  27. brainstate/compile/_make_jaxpr.py +888 -888
  28. brainstate/compile/_make_jaxpr_test.py +156 -156
  29. brainstate/compile/_progress_bar.py +202 -202
  30. brainstate/compile/_unvmap.py +159 -159
  31. brainstate/compile/_util.py +147 -147
  32. brainstate/environ.py +563 -563
  33. brainstate/environ_test.py +62 -62
  34. brainstate/functional/__init__.py +27 -26
  35. brainstate/graph/__init__.py +29 -29
  36. brainstate/graph/_graph_node.py +244 -244
  37. brainstate/graph/_graph_node_test.py +73 -73
  38. brainstate/graph/_graph_operation.py +1738 -1738
  39. brainstate/graph/_graph_operation_test.py +563 -563
  40. brainstate/init/__init__.py +26 -26
  41. brainstate/init/_base.py +52 -52
  42. brainstate/init/_generic.py +244 -244
  43. brainstate/init/_random_inits.py +553 -553
  44. brainstate/init/_random_inits_test.py +149 -149
  45. brainstate/init/_regular_inits.py +105 -105
  46. brainstate/init/_regular_inits_test.py +50 -50
  47. brainstate/mixin.py +365 -363
  48. brainstate/mixin_test.py +77 -73
  49. brainstate/nn/__init__.py +135 -131
  50. brainstate/{functional → nn}/_activations.py +808 -813
  51. brainstate/{functional → nn}/_activations_test.py +331 -331
  52. brainstate/nn/_collective_ops.py +514 -514
  53. brainstate/nn/_collective_ops_test.py +43 -43
  54. brainstate/nn/_common.py +178 -178
  55. brainstate/nn/_conv.py +501 -501
  56. brainstate/nn/_conv_test.py +238 -238
  57. brainstate/nn/_delay.py +588 -502
  58. brainstate/nn/_delay_test.py +238 -184
  59. brainstate/nn/_dropout.py +426 -426
  60. brainstate/nn/_dropout_test.py +100 -100
  61. brainstate/nn/_dynamics.py +1343 -1343
  62. brainstate/nn/_dynamics_test.py +78 -78
  63. brainstate/nn/_elementwise.py +1119 -1119
  64. brainstate/nn/_elementwise_test.py +169 -169
  65. brainstate/nn/_embedding.py +58 -58
  66. brainstate/nn/_exp_euler.py +92 -92
  67. brainstate/nn/_exp_euler_test.py +35 -35
  68. brainstate/nn/_fixedprob.py +239 -239
  69. brainstate/nn/_fixedprob_test.py +114 -114
  70. brainstate/nn/_inputs.py +608 -608
  71. brainstate/nn/_linear.py +424 -424
  72. brainstate/nn/_linear_mv.py +83 -83
  73. brainstate/nn/_linear_mv_test.py +120 -120
  74. brainstate/nn/_linear_test.py +107 -107
  75. brainstate/nn/_ltp.py +28 -28
  76. brainstate/nn/_module.py +377 -377
  77. brainstate/nn/_module_test.py +40 -40
  78. brainstate/nn/_neuron.py +705 -705
  79. brainstate/nn/_neuron_test.py +161 -161
  80. brainstate/nn/_normalizations.py +975 -918
  81. brainstate/nn/_normalizations_test.py +73 -73
  82. brainstate/{functional → nn}/_others.py +46 -46
  83. brainstate/nn/_poolings.py +1177 -1177
  84. brainstate/nn/_poolings_test.py +217 -217
  85. brainstate/nn/_projection.py +486 -486
  86. brainstate/nn/_rate_rnns.py +554 -554
  87. brainstate/nn/_rate_rnns_test.py +63 -63
  88. brainstate/nn/_readout.py +209 -209
  89. brainstate/nn/_readout_test.py +53 -53
  90. brainstate/nn/_stp.py +236 -236
  91. brainstate/nn/_synapse.py +505 -505
  92. brainstate/nn/_synapse_test.py +131 -131
  93. brainstate/nn/_synaptic_projection.py +423 -423
  94. brainstate/nn/_synouts.py +162 -162
  95. brainstate/nn/_synouts_test.py +57 -57
  96. brainstate/nn/_utils.py +89 -89
  97. brainstate/nn/metrics.py +388 -388
  98. brainstate/optim/__init__.py +38 -38
  99. brainstate/optim/_base.py +64 -64
  100. brainstate/optim/_lr_scheduler.py +448 -448
  101. brainstate/optim/_lr_scheduler_test.py +50 -50
  102. brainstate/optim/_optax_optimizer.py +152 -152
  103. brainstate/optim/_optax_optimizer_test.py +53 -53
  104. brainstate/optim/_sgd_optimizer.py +1104 -1104
  105. brainstate/random/__init__.py +24 -24
  106. brainstate/random/_rand_funs.py +3616 -3616
  107. brainstate/random/_rand_funs_test.py +567 -567
  108. brainstate/random/_rand_seed.py +210 -210
  109. brainstate/random/_rand_seed_test.py +48 -48
  110. brainstate/random/_rand_state.py +1409 -1409
  111. brainstate/random/_random_for_unit.py +52 -52
  112. brainstate/surrogate.py +1957 -1957
  113. brainstate/transform.py +23 -23
  114. brainstate/typing.py +304 -304
  115. brainstate/util/__init__.py +50 -50
  116. brainstate/util/caller.py +98 -98
  117. brainstate/util/error.py +55 -55
  118. brainstate/util/filter.py +469 -469
  119. brainstate/util/others.py +540 -540
  120. brainstate/util/pretty_pytree.py +945 -945
  121. brainstate/util/pretty_pytree_test.py +159 -159
  122. brainstate/util/pretty_repr.py +328 -328
  123. brainstate/util/pretty_table.py +2954 -2954
  124. brainstate/util/scaling.py +258 -258
  125. brainstate/util/struct.py +523 -523
  126. {brainstate-0.1.8.dist-info → brainstate-0.1.10.dist-info}/METADATA +91 -99
  127. brainstate-0.1.10.dist-info/RECORD +130 -0
  128. {brainstate-0.1.8.dist-info → brainstate-0.1.10.dist-info}/WHEEL +1 -1
  129. {brainstate-0.1.8.dist-info → brainstate-0.1.10.dist-info/licenses}/LICENSE +202 -202
  130. brainstate/functional/_normalization.py +0 -81
  131. brainstate/functional/_spikes.py +0 -204
  132. brainstate-0.1.8.dist-info/RECORD +0 -132
  133. {brainstate-0.1.8.dist-info → brainstate-0.1.10.dist-info}/top_level.txt +0 -0
@@ -1,486 +1,486 @@
1
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
- from typing import Callable, Union
17
- from typing import Optional
18
-
19
- import brainevent
20
- import brainunit as u
21
-
22
- from brainstate._state import State
23
- from brainstate.mixin import BindCondData, JointTypes
24
- from brainstate.mixin import ParamDescriber, AlignPost
25
- from brainstate.util.others import get_unique_name
26
- from ._collective_ops import call_order
27
- from ._dynamics import Dynamics, Projection, maybe_init_prefetch, Prefetch, PrefetchDelayAt
28
- from ._module import Module
29
- from ._stp import ShortTermPlasticity
30
- from ._synapse import Synapse
31
- from ._synouts import SynOut
32
-
33
- __all__ = [
34
- 'AlignPostProj',
35
- 'DeltaProj',
36
- 'CurrentProj',
37
-
38
- 'align_pre_projection',
39
- 'align_post_projection',
40
- ]
41
-
42
-
43
- def _check_modules(*modules):
44
- # checking modules
45
- for module in modules:
46
- if not callable(module) and not isinstance(module, State):
47
- raise TypeError(
48
- f'The module should be a callable function or a brainstate.State, but got {module}.'
49
- )
50
- return tuple(modules)
51
-
52
-
53
- def call_module(module, *args, **kwargs):
54
- if callable(module):
55
- return module(*args, **kwargs)
56
- elif isinstance(module, State):
57
- return module.value
58
- else:
59
- raise TypeError(
60
- f'The module should be a callable function or a brainstate.State, but got {module}.'
61
- )
62
-
63
-
64
- def is_instance(x, cls) -> bool:
65
- return isinstance(x, cls)
66
-
67
-
68
- def get_post_repr(label, syn, out):
69
- if label is None:
70
- return f'{syn.identifier} // {out.identifier}'
71
- else:
72
- return f'{label}{syn.identifier} // {out.identifier}'
73
-
74
-
75
- def align_post_add_bef_update(
76
- syn_desc: ParamDescriber[AlignPost],
77
- out_desc: ParamDescriber[BindCondData],
78
- post: Dynamics,
79
- proj_name: str,
80
- label: str,
81
- ):
82
- # synapse and output initialization
83
- _post_repr = get_post_repr(label, syn_desc, out_desc)
84
- if not post._has_before_update(_post_repr):
85
- syn_cls = syn_desc()
86
- out_cls = out_desc()
87
-
88
- # synapse and output initialization
89
- post.add_current_input(proj_name, out_cls, label=label)
90
- post._add_before_update(_post_repr, _AlignPost(syn_cls, out_cls))
91
- syn = post._get_before_update(_post_repr).syn
92
- out = post._get_before_update(_post_repr).out
93
- return syn, out
94
-
95
-
96
- class _AlignPost(Module):
97
- def __init__(
98
- self,
99
- syn: Dynamics,
100
- out: BindCondData
101
- ):
102
- super().__init__()
103
- self.syn = syn
104
- self.out = out
105
-
106
- def update(self, *args, **kwargs):
107
- self.out.bind_cond(self.syn(*args, **kwargs))
108
-
109
-
110
- class AlignPostProj(Projection):
111
- """
112
- Align-post projection of the neural network.
113
-
114
-
115
- Examples
116
- --------
117
-
118
- Here is an example of using the `AlignPostProj` to create a synaptic projection.
119
- Note that this projection needs the manual input of pre-synaptic spikes.
120
-
121
- >>> import brainstate
122
- >>> import brainunit as u
123
- >>> n_exc = 3200
124
- >>> n_inh = 800
125
- >>> num = n_exc + n_inh
126
- >>> pop = brainstate.nn.LIFRef(
127
- ... num,
128
- ... V_rest=-49. * u.mV, V_th=-50. * u.mV, V_reset=-60. * u.mV,
129
- ... tau=20. * u.ms, tau_ref=5. * u.ms,
130
- ... V_initializer=brainstate.init.Normal(-55., 2., unit=u.mV)
131
- ... )
132
- >>> pop.init_state()
133
- >>> E = brainstate.nn.AlignPostProj(
134
- ... comm=brainstate.nn.FixedNumConn(n_exc, num, prob=80 / num, weight=1.62 * u.mS),
135
- ... syn=brainstate.nn.Expon.desc(num, tau=5. * u.ms),
136
- ... out=brainstate.nn.CUBA.desc(scale=u.volt),
137
- ... post=pop
138
- ... )
139
- >>> exe_current = E(pop.get_spike())
140
-
141
- """
142
- __module__ = 'brainstate.nn'
143
-
144
- def __init__(
145
- self,
146
- *modules,
147
- comm: Callable,
148
- syn: Union[ParamDescriber[AlignPost], AlignPost],
149
- out: Union[ParamDescriber[SynOut], SynOut],
150
- post: Dynamics,
151
- label: Optional[str] = None,
152
- ):
153
- super().__init__(name=get_unique_name(self.__class__.__name__))
154
-
155
- # checking modules
156
- self.modules = _check_modules(*modules)
157
-
158
- # checking communication model
159
- if not callable(comm):
160
- raise TypeError(
161
- f'The communication should be an instance of callable function, but got {comm}.'
162
- )
163
-
164
- # checking synapse and output models
165
- if is_instance(syn, ParamDescriber[AlignPost]):
166
- if not is_instance(out, ParamDescriber[SynOut]):
167
- if is_instance(out, ParamDescriber):
168
- raise TypeError(
169
- f'The output should be an instance of describer {ParamDescriber[SynOut]} when '
170
- f'the synapse is an instance of {AlignPost}, but got {out}.'
171
- )
172
- raise TypeError(
173
- f'The output should be an instance of describer {ParamDescriber[SynOut]} when '
174
- f'the synapse is a describer, but we got {out}.'
175
- )
176
- merging = True
177
- else:
178
- if is_instance(syn, ParamDescriber):
179
- raise TypeError(
180
- f'The synapse should be an instance of describer {ParamDescriber[AlignPost]}, but got {syn}.'
181
- )
182
- if not is_instance(out, SynOut):
183
- raise TypeError(
184
- f'The output should be an instance of {SynOut} when the synapse is '
185
- f'not a describer, but we got {out}.'
186
- )
187
- merging = False
188
- self.merging = merging
189
-
190
- # checking post model
191
- if not is_instance(post, Dynamics):
192
- raise TypeError(
193
- f'The post should be an instance of {Dynamics}, but got {post}.'
194
- )
195
-
196
- if merging:
197
- # synapse and output initialization
198
- syn, out = align_post_add_bef_update(syn_desc=syn,
199
- out_desc=out,
200
- post=post,
201
- proj_name=self.name,
202
- label=label)
203
- else:
204
- post.add_current_input(self.name, out)
205
-
206
- # references
207
- self.comm = comm
208
- self.syn: JointTypes[Dynamics, AlignPost] = syn
209
- self.out: BindCondData = out
210
- self.post: Dynamics = post
211
-
212
- @call_order(2)
213
- def init_state(self, *args, **kwargs):
214
- for module in self.modules:
215
- maybe_init_prefetch(module, *args, **kwargs)
216
-
217
- def update(self, *args):
218
- # call all modules
219
- for module in self.modules:
220
- x = call_module(module, *args)
221
- args = (x,)
222
- # communication module
223
- x = self.comm(*args)
224
- # add synapse input
225
- self.syn.add_delta_input(self.name, x)
226
- if not self.merging:
227
- # synapse and output interaction
228
- conductance = self.syn()
229
- self.out.bind_cond(conductance)
230
-
231
-
232
- class DeltaProj(Projection):
233
- """
234
- Delta-based projection of the neural network.
235
-
236
- This projection directly applies delta inputs to post-synaptic neurons without intervening
237
- synaptic dynamics. It processes inputs through optional prefetch modules, applies a communication model,
238
- and adds the result directly as a delta input to the post-synaptic population.
239
-
240
- Parameters
241
- ----------
242
- *prefetch : State or callable
243
- Optional prefetch modules to process input before communication.
244
- comm : callable
245
- Communication model that determines how signals are transmitted.
246
- post : Dynamics
247
- Post-synaptic neural population to receive the delta inputs.
248
- label : Optional[str], default=None
249
- Optional label for the projection to identify it in the post-synaptic population.
250
-
251
- Examples
252
- --------
253
- >>> import brainstate
254
- >>> import brainunit as u
255
- >>> n_neurons = 100
256
- >>> pop = brainstate.nn.LIF(n_neurons, V_rest=-70*u.mV, V_threshold=-50*u.mV)
257
- >>> pop.init_state()
258
- >>> delta_input = brainstate.nn.DeltaProj(
259
- ... comm=lambda x: x * 10.0*u.mV,
260
- ... post=pop
261
- ... )
262
- >>> delta_input(1.0) # Apply voltage increment directly
263
- """
264
- __module__ = 'brainstate.nn'
265
-
266
- def __init__(self, *prefetch, comm: Callable, post: Dynamics, label=None):
267
- super().__init__(name=get_unique_name(self.__class__.__name__))
268
-
269
- self.label = label
270
-
271
- # checking modules
272
- self.prefetches = _check_modules(*prefetch)
273
-
274
- # checking communication model
275
- if not callable(comm):
276
- raise TypeError(
277
- f'The communication should be an instance of callable function, but got {comm}.'
278
- )
279
- self.comm = comm
280
-
281
- # post model
282
- if not isinstance(post, Dynamics):
283
- raise TypeError(
284
- f'The post should be an instance of {Dynamics}, but got {post}.'
285
- )
286
- self.post = post
287
-
288
- @call_order(2)
289
- def init_state(self, *args, **kwargs):
290
- for prefetch in self.prefetches:
291
- maybe_init_prefetch(prefetch, *args, **kwargs)
292
-
293
- def update(self, *x):
294
- for module in self.prefetches:
295
- x = (call_module(module, *x),)
296
- assert len(x) == 1, f'The output of the modules should be a single value, but got {x}.'
297
- x = self.comm(x[0])
298
- self.post.add_delta_input(self.name, x, label=self.label)
299
-
300
-
301
- class CurrentProj(Projection):
302
- """
303
- Current-based projection of the neural network.
304
-
305
- This projection directly modulates post-synaptic currents without separate synaptic dynamics.
306
- It processes inputs through optional prefetch modules, applies a communication model,
307
- and binds the result to the output model which is then added as a current input to the post-synaptic population.
308
-
309
- Parameters
310
- ----------
311
- *prefetch : State or callable
312
- Optional prefetch modules to process input before communication.
313
- The last element must be an instance of Prefetch or PrefetchDelayAt if any are provided.
314
- comm : callable
315
- Communication model that determines how signals are transmitted.
316
- out : SynOut
317
- Output model that converts communication results to post-synaptic currents.
318
- post : Dynamics
319
- Post-synaptic neural population to receive the currents.
320
-
321
- Examples
322
- --------
323
- >>> import brainstate
324
- >>> import brainunit as u
325
- >>> n_neurons = 100
326
- >>> pop = brainstate.nn.LIF(n_neurons, V_rest=-70*u.mV, V_threshold=-50*u.mV)
327
- >>> pop.init_state()
328
- >>> current_input = brainstate.nn.CurrentProj(
329
- ... comm=lambda x: x * 0.5,
330
- ... out=brainstate.nn.CUBA(scale=1.0*u.nA),
331
- ... post=pop
332
- ... )
333
- >>> current_input(0.2) # Apply external current
334
- """
335
- __module__ = 'brainstate.nn'
336
-
337
- def __init__(
338
- self,
339
- *prefetch,
340
- comm: Callable,
341
- out: SynOut,
342
- post: Dynamics,
343
- ):
344
- super().__init__(name=get_unique_name(self.__class__.__name__))
345
-
346
- # check prefetch
347
- self.prefetch = prefetch
348
- if len(self.prefetch) > 0 and not isinstance(prefetch[-1], (Prefetch, PrefetchDelayAt)):
349
- raise TypeError(
350
- f'The last element of prefetch should be an instance of {Prefetch} or {PrefetchDelayAt}, '
351
- f'but got {prefetch[-1]}.'
352
- )
353
-
354
- # check out
355
- if not isinstance(out, SynOut):
356
- raise TypeError(f'The out should be a SynOut, but got {out}.')
357
- self.out = out
358
-
359
- # check post
360
- if not isinstance(post, Dynamics):
361
- raise TypeError(f'The post should be a Dynamics, but got {post}.')
362
- self.post = post
363
- post.add_current_input(self.name, out)
364
-
365
- # output initialization
366
- self.comm = comm
367
-
368
- @call_order(2)
369
- def init_state(self, *args, **kwargs):
370
- for prefetch in self.prefetch:
371
- maybe_init_prefetch(prefetch, *args, **kwargs)
372
-
373
- def update(self, *x):
374
- for prefetch in self.prefetch:
375
- x = (call_module(prefetch, *x),)
376
- x = self.comm(*x)
377
- self.out.bind_cond(x)
378
-
379
-
380
- class align_pre_projection(Projection):
381
- """
382
- Represents a pre-synaptic alignment projection mechanism.
383
-
384
- This class inherits from the `Projection` base class and is designed to
385
- manage the pre-synaptic alignment process in neural network simulations.
386
- It takes into account pre-synaptic dynamics, synaptic properties, delays,
387
- communication functions, synaptic outputs, post-synaptic dynamics, and
388
- short-term plasticity.
389
-
390
- Attributes:
391
- pre (Dynamics): The pre-synaptic dynamics object.
392
- syn (Synapse): The synaptic object after pre-synaptic alignment.
393
- delay (u.Quantity[u.second]): The output delay from the synapse.
394
- projection (CurrentProj): The current projection object handling communication,
395
- output, and post-synaptic dynamics.
396
- stp (ShortTermPlasticity, optional): The short-term plasticity object,
397
- defaults to None.
398
- """
399
-
400
- def __init__(
401
- self,
402
- *spike_generator,
403
- syn: Dynamics,
404
- comm: Callable,
405
- out: SynOut,
406
- post: Dynamics,
407
- stp: ShortTermPlasticity = None,
408
- ):
409
- super().__init__()
410
-
411
- self.spike_generator = _check_modules(*spike_generator)
412
- self.projection = CurrentProj(comm=comm, out=out, post=post)
413
- self.syn = syn
414
- self.stp = stp
415
-
416
- @call_order(2)
417
- def init_state(self, *args, **kwargs):
418
- for module in self.spike_generator:
419
- maybe_init_prefetch(module, *args, **kwargs)
420
-
421
- def update(self, *x):
422
- for fun in self.spike_generator:
423
- x = fun(*x)
424
- if isinstance(x, (tuple, list)):
425
- x = tuple(x)
426
- else:
427
- x = (x,)
428
- assert len(x) == 1, "Spike generator must return a single value or a tuple/list of values"
429
- x = brainevent.BinaryArray(x[0]) # Ensure input is a BinaryFloat for spike generation
430
- if self.stp is not None:
431
- x = brainevent.MaskedFloat(self.stp(x)) # Ensure STP output is a MaskedFloat
432
- x = self.syn(x) # Apply pre-synaptic alignment
433
- return self.projection(x)
434
-
435
-
436
- class align_post_projection(Projection):
437
- """
438
- Represents a post-synaptic alignment projection mechanism.
439
-
440
- This class inherits from the `Projection` base class and is designed to
441
- manage the post-synaptic alignment process in neural network simulations.
442
- It takes into account spike generators, communication functions, synaptic
443
- properties, synaptic outputs, post-synaptic dynamics, and short-term plasticity.
444
-
445
- Args:
446
- *spike_generator: Callable(s) that generate spike events or transform input spikes.
447
- comm (Callable): Communication function for the projection.
448
- syn (Union[AlignPost, ParamDescriber[AlignPost]]): The post-synaptic alignment object or its parameter describer.
449
- out (Union[SynOut, ParamDescriber[SynOut]]): The synaptic output object or its parameter describer.
450
- post (Dynamics): The post-synaptic dynamics object.
451
- stp (ShortTermPlasticity, optional): The short-term plasticity object, defaults to None.
452
-
453
- """
454
-
455
- def __init__(
456
- self,
457
- *spike_generator,
458
- comm: Callable,
459
- syn: Union[AlignPost, ParamDescriber[AlignPost]],
460
- out: Union[SynOut, ParamDescriber[SynOut]],
461
- post: Dynamics,
462
- stp: ShortTermPlasticity = None,
463
- ):
464
- super().__init__()
465
-
466
- self.spike_generator = _check_modules(*spike_generator)
467
- self.projection = AlignPostProj(comm=comm, syn=syn, out=out, post=post)
468
- self.stp = stp
469
-
470
- @call_order(2)
471
- def init_state(self, *args, **kwargs):
472
- for module in self.spike_generator:
473
- maybe_init_prefetch(module, *args, **kwargs)
474
-
475
- def update(self, *x):
476
- for fun in self.spike_generator:
477
- x = fun(*x)
478
- if isinstance(x, (tuple, list)):
479
- x = tuple(x)
480
- else:
481
- x = (x,)
482
- assert len(x) == 1, "Spike generator must return a single value or a tuple/list of values"
483
- x = brainevent.BinaryArray(x[0]) # Ensure input is a BinaryFloat for spike generation
484
- if self.stp is not None:
485
- x = brainevent.MaskedFloat(self.stp(x)) # Ensure STP output is a MaskedFloat
486
- return self.projection(x)
1
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ from typing import Callable, Union
17
+ from typing import Optional
18
+
19
+ import brainevent
20
+ import brainunit as u
21
+
22
+ from brainstate._state import State
23
+ from brainstate.mixin import BindCondData, JointTypes
24
+ from brainstate.mixin import ParamDescriber, AlignPost
25
+ from brainstate.util.others import get_unique_name
26
+ from ._collective_ops import call_order
27
+ from ._dynamics import Dynamics, Projection, maybe_init_prefetch, Prefetch, PrefetchDelayAt
28
+ from ._module import Module
29
+ from ._stp import ShortTermPlasticity
30
+ from ._synapse import Synapse
31
+ from ._synouts import SynOut
32
+
33
+ __all__ = [
34
+ 'AlignPostProj',
35
+ 'DeltaProj',
36
+ 'CurrentProj',
37
+
38
+ 'align_pre_projection',
39
+ 'align_post_projection',
40
+ ]
41
+
42
+
43
+ def _check_modules(*modules):
44
+ # checking modules
45
+ for module in modules:
46
+ if not callable(module) and not isinstance(module, State):
47
+ raise TypeError(
48
+ f'The module should be a callable function or a brainstate.State, but got {module}.'
49
+ )
50
+ return tuple(modules)
51
+
52
+
53
+ def call_module(module, *args, **kwargs):
54
+ if callable(module):
55
+ return module(*args, **kwargs)
56
+ elif isinstance(module, State):
57
+ return module.value
58
+ else:
59
+ raise TypeError(
60
+ f'The module should be a callable function or a brainstate.State, but got {module}.'
61
+ )
62
+
63
+
64
+ def is_instance(x, cls) -> bool:
65
+ return isinstance(x, cls)
66
+
67
+
68
+ def get_post_repr(label, syn, out):
69
+ if label is None:
70
+ return f'{syn.identifier} // {out.identifier}'
71
+ else:
72
+ return f'{label}{syn.identifier} // {out.identifier}'
73
+
74
+
75
+ def align_post_add_bef_update(
76
+ syn_desc: ParamDescriber[AlignPost],
77
+ out_desc: ParamDescriber[BindCondData],
78
+ post: Dynamics,
79
+ proj_name: str,
80
+ label: str,
81
+ ):
82
+ # synapse and output initialization
83
+ _post_repr = get_post_repr(label, syn_desc, out_desc)
84
+ if not post._has_before_update(_post_repr):
85
+ syn_cls = syn_desc()
86
+ out_cls = out_desc()
87
+
88
+ # synapse and output initialization
89
+ post.add_current_input(proj_name, out_cls, label=label)
90
+ post._add_before_update(_post_repr, _AlignPost(syn_cls, out_cls))
91
+ syn = post._get_before_update(_post_repr).syn
92
+ out = post._get_before_update(_post_repr).out
93
+ return syn, out
94
+
95
+
96
+ class _AlignPost(Module):
97
+ def __init__(
98
+ self,
99
+ syn: Dynamics,
100
+ out: BindCondData
101
+ ):
102
+ super().__init__()
103
+ self.syn = syn
104
+ self.out = out
105
+
106
+ def update(self, *args, **kwargs):
107
+ self.out.bind_cond(self.syn(*args, **kwargs))
108
+
109
+
110
+ class AlignPostProj(Projection):
111
+ """
112
+ Align-post projection of the neural network.
113
+
114
+
115
+ Examples
116
+ --------
117
+
118
+ Here is an example of using the `AlignPostProj` to create a synaptic projection.
119
+ Note that this projection needs the manual input of pre-synaptic spikes.
120
+
121
+ >>> import brainstate
122
+ >>> import brainunit as u
123
+ >>> n_exc = 3200
124
+ >>> n_inh = 800
125
+ >>> num = n_exc + n_inh
126
+ >>> pop = brainstate.nn.LIFRef(
127
+ ... num,
128
+ ... V_rest=-49. * u.mV, V_th=-50. * u.mV, V_reset=-60. * u.mV,
129
+ ... tau=20. * u.ms, tau_ref=5. * u.ms,
130
+ ... V_initializer=brainstate.init.Normal(-55., 2., unit=u.mV)
131
+ ... )
132
+ >>> pop.init_state()
133
+ >>> E = brainstate.nn.AlignPostProj(
134
+ ... comm=brainstate.nn.FixedNumConn(n_exc, num, prob=80 / num, weight=1.62 * u.mS),
135
+ ... syn=brainstate.nn.Expon.desc(num, tau=5. * u.ms),
136
+ ... out=brainstate.nn.CUBA.desc(scale=u.volt),
137
+ ... post=pop
138
+ ... )
139
+ >>> exe_current = E(pop.get_spike())
140
+
141
+ """
142
+ __module__ = 'brainstate.nn'
143
+
144
+ def __init__(
145
+ self,
146
+ *modules,
147
+ comm: Callable,
148
+ syn: Union[ParamDescriber[AlignPost], AlignPost],
149
+ out: Union[ParamDescriber[SynOut], SynOut],
150
+ post: Dynamics,
151
+ label: Optional[str] = None,
152
+ ):
153
+ super().__init__(name=get_unique_name(self.__class__.__name__))
154
+
155
+ # checking modules
156
+ self.modules = _check_modules(*modules)
157
+
158
+ # checking communication model
159
+ if not callable(comm):
160
+ raise TypeError(
161
+ f'The communication should be an instance of callable function, but got {comm}.'
162
+ )
163
+
164
+ # checking synapse and output models
165
+ if is_instance(syn, ParamDescriber[AlignPost]):
166
+ if not is_instance(out, ParamDescriber[SynOut]):
167
+ if is_instance(out, ParamDescriber):
168
+ raise TypeError(
169
+ f'The output should be an instance of describer {ParamDescriber[SynOut]} when '
170
+ f'the synapse is an instance of {AlignPost}, but got {out}.'
171
+ )
172
+ raise TypeError(
173
+ f'The output should be an instance of describer {ParamDescriber[SynOut]} when '
174
+ f'the synapse is a describer, but we got {out}.'
175
+ )
176
+ merging = True
177
+ else:
178
+ if is_instance(syn, ParamDescriber):
179
+ raise TypeError(
180
+ f'The synapse should be an instance of describer {ParamDescriber[AlignPost]}, but got {syn}.'
181
+ )
182
+ if not is_instance(out, SynOut):
183
+ raise TypeError(
184
+ f'The output should be an instance of {SynOut} when the synapse is '
185
+ f'not a describer, but we got {out}.'
186
+ )
187
+ merging = False
188
+ self.merging = merging
189
+
190
+ # checking post model
191
+ if not is_instance(post, Dynamics):
192
+ raise TypeError(
193
+ f'The post should be an instance of {Dynamics}, but got {post}.'
194
+ )
195
+
196
+ if merging:
197
+ # synapse and output initialization
198
+ syn, out = align_post_add_bef_update(syn_desc=syn,
199
+ out_desc=out,
200
+ post=post,
201
+ proj_name=self.name,
202
+ label=label)
203
+ else:
204
+ post.add_current_input(self.name, out)
205
+
206
+ # references
207
+ self.comm = comm
208
+ self.syn: JointTypes[Dynamics, AlignPost] = syn
209
+ self.out: BindCondData = out
210
+ self.post: Dynamics = post
211
+
212
+ @call_order(2)
213
+ def init_state(self, *args, **kwargs):
214
+ for module in self.modules:
215
+ maybe_init_prefetch(module, *args, **kwargs)
216
+
217
+ def update(self, *args):
218
+ # call all modules
219
+ for module in self.modules:
220
+ x = call_module(module, *args)
221
+ args = (x,)
222
+ # communication module
223
+ x = self.comm(*args)
224
+ # add synapse input
225
+ self.syn.add_delta_input(self.name, x)
226
+ if not self.merging:
227
+ # synapse and output interaction
228
+ conductance = self.syn()
229
+ self.out.bind_cond(conductance)
230
+
231
+
232
class DeltaProj(Projection):
    """
    Delta-based projection of the neural network.

    This projection directly applies delta inputs to post-synaptic neurons without intervening
    synaptic dynamics. It processes inputs through optional prefetch modules, applies a communication model,
    and adds the result directly as a delta input to the post-synaptic population.

    Parameters
    ----------
    *prefetch : State or callable
        Optional prefetch modules to process input before communication.
    comm : callable
        Communication model that determines how signals are transmitted.
    post : Dynamics
        Post-synaptic neural population to receive the delta inputs.
    label : Optional[str], default=None
        Optional label for the projection to identify it in the post-synaptic population.

    Raises
    ------
    TypeError
        If ``comm`` is not callable or ``post`` is not a ``Dynamics`` instance.

    Examples
    --------
    >>> import brainstate
    >>> import brainunit as u
    >>> n_neurons = 100
    >>> pop = brainstate.nn.LIF(n_neurons, V_rest=-70*u.mV, V_threshold=-50*u.mV)
    >>> pop.init_state()
    >>> delta_input = brainstate.nn.DeltaProj(
    ...     comm=lambda x: x * 10.0*u.mV,
    ...     post=pop
    ... )
    >>> delta_input(1.0)  # Apply voltage increment directly
    """
    __module__ = 'brainstate.nn'

    def __init__(self, *prefetch, comm: Callable, post: Dynamics, label=None):
        super().__init__(name=get_unique_name(self.__class__.__name__))

        # Forwarded to ``post.add_delta_input`` on every update.
        self.label = label

        # checking modules
        self.prefetches = _check_modules(*prefetch)

        # checking communication model
        if not callable(comm):
            raise TypeError(
                f'The communication should be an instance of callable function, but got {comm}.'
            )
        self.comm = comm

        # post model
        if not isinstance(post, Dynamics):
            raise TypeError(
                f'The post should be an instance of {Dynamics}, but got {post}.'
            )
        self.post = post

    @call_order(2)
    def init_state(self, *args, **kwargs):
        """Pass the initialization arguments to each prefetch module via ``maybe_init_prefetch``."""
        for prefetch in self.prefetches:
            maybe_init_prefetch(prefetch, *args, **kwargs)

    def update(self, *x):
        """
        Apply the prefetch chain and ``comm``, then add the result as a delta input to ``post``.

        Raises
        ------
        ValueError
            If the prefetch chain does not yield exactly one value. This can only happen
            when there are no prefetch modules and more than one positional argument is given.
        """
        for module in self.prefetches:
            x = (call_module(module, *x),)
        # Explicit check instead of ``assert`` so validation survives ``python -O``.
        if len(x) != 1:
            raise ValueError(f'The output of the modules should be a single value, but got {x}.')
        x = self.comm(x[0])
        self.post.add_delta_input(self.name, x, label=self.label)
299
+
300
+
301
class CurrentProj(Projection):
    """
    Current-based projection that drives post-synaptic currents directly.

    No dedicated synapse dynamics sit between input and target. The inputs are
    threaded through the optional prefetch chain and transformed by ``comm``,
    and the result is bound to the output model ``out`` via ``bind_cond``.
    ``out`` is registered on ``post`` as a current input when the projection
    is constructed.

    Parameters
    ----------
    *prefetch : State or callable
        Optional modules applied in order before communication. When any are
        given, the final one must be an instance of Prefetch or PrefetchDelayAt.
    comm : callable
        Communication model applied to the output of the prefetch chain.
    out : SynOut
        Output model that receives the communication result.
    post : Dynamics
        Post-synaptic population; ``out`` is added to it as a current input.

    Raises
    ------
    TypeError
        If the final prefetch element, ``out``, or ``post`` has the wrong type.

    Examples
    --------
    >>> import brainstate
    >>> import brainunit as u
    >>> n_neurons = 100
    >>> pop = brainstate.nn.LIF(n_neurons, V_rest=-70*u.mV, V_threshold=-50*u.mV)
    >>> pop.init_state()
    >>> current_input = brainstate.nn.CurrentProj(
    ...     comm=lambda x: x * 0.5,
    ...     out=brainstate.nn.CUBA(scale=1.0*u.nA),
    ...     post=pop
    ... )
    >>> current_input(0.2)  # Apply external current
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        *prefetch,
        comm: Callable,
        out: SynOut,
        post: Dynamics,
    ):
        super().__init__(name=get_unique_name(self.__class__.__name__))

        # The prefetch tuple is stored as given; only its final element is type-checked.
        self.prefetch = prefetch
        n_prefetch = len(self.prefetch)
        if n_prefetch > 0:
            tail = prefetch[-1]
            if not isinstance(tail, (Prefetch, PrefetchDelayAt)):
                raise TypeError(
                    f'The last element of prefetch should be an instance of {Prefetch} or {PrefetchDelayAt}, '
                    f'but got {tail}.'
                )

        # Output model: validated, stored, then attached to the target below.
        if not isinstance(out, SynOut):
            raise TypeError(f'The out should be a SynOut, but got {out}.')
        self.out = out

        # Target population: validated, stored, and given ``out`` as a current input.
        if not isinstance(post, Dynamics):
            raise TypeError(f'The post should be a Dynamics, but got {post}.')
        self.post = post
        post.add_current_input(self.name, out)

        self.comm = comm

    @call_order(2)
    def init_state(self, *args, **kwargs):
        """Initialize every prefetch module through ``maybe_init_prefetch``."""
        for stage in self.prefetch:
            maybe_init_prefetch(stage, *args, **kwargs)

    def update(self, *x):
        """Thread inputs through the prefetch chain and ``comm``, then bind the result to ``out``."""
        values = x
        for stage in self.prefetch:
            values = (call_module(stage, *values),)
        current = self.comm(*values)
        self.out.bind_cond(current)
378
+
379
+
380
class align_pre_projection(Projection):
    """
    Pre-synaptic alignment projection.

    Inputs pass through the following stages on each update:

    1. The spike-generator modules, applied in order.
    2. Wrapping as ``brainevent.BinaryArray``.
    3. The optional short-term plasticity model, whose output is wrapped as
       ``brainevent.MaskedFloat``.
    4. The synapse model ``syn``.
    5. Delivery to ``post`` through an internal :class:`CurrentProj` built from
       ``comm`` and ``out``.

    Parameters
    ----------
    *spike_generator : callable
        Modules applied in order to the arguments of :meth:`update`. Each stage
        receives the previous stage's output(s) as positional arguments.
    syn : Dynamics
        Synapse model applied to the (possibly STP-modulated) spikes.
    comm : callable
        Communication model forwarded to the internal ``CurrentProj``.
    out : SynOut
        Synaptic output model forwarded to the internal ``CurrentProj``.
    post : Dynamics
        Post-synaptic population that receives the resulting current.
    stp : ShortTermPlasticity, optional
        Short-term plasticity model applied to the spikes before ``syn``.
        Defaults to None (no STP).

    Attributes
    ----------
    spike_generator
        The spike-generator modules, as returned by ``_check_modules``.
    projection : CurrentProj
        Handles communication, output and delivery to ``post``.
    syn : Dynamics
        The synapse model.
    stp : ShortTermPlasticity or None
        The short-term plasticity model, if any.
    """

    def __init__(
        self,
        *spike_generator,
        syn: Dynamics,
        comm: Callable,
        out: SynOut,
        post: Dynamics,
        stp: ShortTermPlasticity = None,
    ):
        super().__init__()

        self.spike_generator = _check_modules(*spike_generator)
        self.projection = CurrentProj(comm=comm, out=out, post=post)
        self.syn = syn
        self.stp = stp

    @call_order(2)
    def init_state(self, *args, **kwargs):
        """Initialize each spike-generator module through ``maybe_init_prefetch``."""
        for module in self.spike_generator:
            maybe_init_prefetch(module, *args, **kwargs)

    def update(self, *x):
        """
        Run the spike generators, STP (optional), the synapse, and the current projection.

        Returns
        -------
        The value returned by the internal ``CurrentProj`` call.

        Raises
        ------
        ValueError
            If the spike-generator chain (or the raw inputs, when there are no
            generators) does not yield exactly one value.
        """
        for fun in self.spike_generator:
            result = fun(*x)
            # Normalize to a tuple so the next stage can be called with ``*x``.
            x = tuple(result) if isinstance(result, (tuple, list)) else (result,)
        # Explicit check instead of ``assert`` so validation survives ``python -O``.
        if len(x) != 1:
            raise ValueError(
                f'Spike generator must produce exactly one value, but got {len(x)} values.'
            )
        spikes = brainevent.BinaryArray(x[0])  # wrap the single input as a binary event array
        if self.stp is not None:
            spikes = brainevent.MaskedFloat(self.stp(spikes))  # wrap the STP output as a MaskedFloat
        spikes = self.syn(spikes)  # apply the pre-synaptic alignment (synapse) model
        return self.projection(spikes)
434
+
435
+
436
class align_post_projection(Projection):
    """
    Post-synaptic alignment projection.

    Inputs pass through the following stages on each update:

    1. The spike-generator modules, applied in order.
    2. Wrapping as ``brainevent.BinaryArray``.
    3. The optional short-term plasticity model, whose output is wrapped as
       ``brainevent.MaskedFloat``.
    4. An internal :class:`AlignPostProj` built from ``comm``, ``syn``,
       ``out`` and ``post``.

    Parameters
    ----------
    *spike_generator : callable
        Modules applied in order to the arguments of :meth:`update`. Each stage
        receives the previous stage's output(s) as positional arguments.
    comm : callable
        Communication function forwarded to the internal ``AlignPostProj``.
    syn : AlignPost or ParamDescriber[AlignPost]
        Post-synaptic alignment object (or its parameter describer).
    out : SynOut or ParamDescriber[SynOut]
        Synaptic output object (or its parameter describer).
    post : Dynamics
        Post-synaptic population.
    stp : ShortTermPlasticity, optional
        Short-term plasticity model applied to the spikes. Defaults to None (no STP).

    Attributes
    ----------
    spike_generator
        The spike-generator modules, as returned by ``_check_modules``.
    projection : AlignPostProj
        Handles communication, synapse, output and delivery to ``post``.
    stp : ShortTermPlasticity or None
        The short-term plasticity model, if any.
    """

    def __init__(
        self,
        *spike_generator,
        comm: Callable,
        syn: Union[AlignPost, ParamDescriber[AlignPost]],
        out: Union[SynOut, ParamDescriber[SynOut]],
        post: Dynamics,
        stp: ShortTermPlasticity = None,
    ):
        super().__init__()

        self.spike_generator = _check_modules(*spike_generator)
        self.projection = AlignPostProj(comm=comm, syn=syn, out=out, post=post)
        self.stp = stp

    @call_order(2)
    def init_state(self, *args, **kwargs):
        """Initialize each spike-generator module through ``maybe_init_prefetch``."""
        for module in self.spike_generator:
            maybe_init_prefetch(module, *args, **kwargs)

    def update(self, *x):
        """
        Run the spike generators and STP (optional), then the internal ``AlignPostProj``.

        Returns
        -------
        The value returned by the internal ``AlignPostProj`` call.

        Raises
        ------
        ValueError
            If the spike-generator chain (or the raw inputs, when there are no
            generators) does not yield exactly one value.
        """
        for fun in self.spike_generator:
            result = fun(*x)
            # Normalize to a tuple so the next stage can be called with ``*x``.
            x = tuple(result) if isinstance(result, (tuple, list)) else (result,)
        # Explicit check instead of ``assert`` so validation survives ``python -O``.
        if len(x) != 1:
            raise ValueError(
                f'Spike generator must produce exactly one value, but got {len(x)} values.'
            )
        spikes = brainevent.BinaryArray(x[0])  # wrap the single input as a binary event array
        if self.stp is not None:
            spikes = brainevent.MaskedFloat(self.stp(spikes))  # wrap the STP output as a MaskedFloat
        return self.projection(spikes)