brainstate 0.1.10__py2.py3-none-any.whl → 0.2.1__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (163)
  1. brainstate/__init__.py +169 -58
  2. brainstate/_compatible_import.py +340 -148
  3. brainstate/_compatible_import_test.py +681 -0
  4. brainstate/_deprecation.py +210 -0
  5. brainstate/_deprecation_test.py +2319 -0
  6. brainstate/{util/error.py → _error.py} +45 -55
  7. brainstate/_state.py +1652 -1605
  8. brainstate/_state_test.py +52 -52
  9. brainstate/_utils.py +47 -47
  10. brainstate/environ.py +1495 -563
  11. brainstate/environ_test.py +1223 -62
  12. brainstate/graph/__init__.py +22 -29
  13. brainstate/graph/_node.py +240 -0
  14. brainstate/graph/_node_test.py +589 -0
  15. brainstate/graph/{_graph_operation.py → _operation.py} +1624 -1738
  16. brainstate/graph/_operation_test.py +1147 -0
  17. brainstate/mixin.py +1433 -365
  18. brainstate/mixin_test.py +1017 -77
  19. brainstate/nn/__init__.py +137 -135
  20. brainstate/nn/_activations.py +1100 -808
  21. brainstate/nn/_activations_test.py +354 -331
  22. brainstate/nn/_collective_ops.py +633 -514
  23. brainstate/nn/_collective_ops_test.py +774 -43
  24. brainstate/nn/_common.py +226 -178
  25. brainstate/nn/_common_test.py +154 -0
  26. brainstate/nn/_conv.py +2010 -501
  27. brainstate/nn/_conv_test.py +849 -238
  28. brainstate/nn/_delay.py +575 -588
  29. brainstate/nn/_delay_test.py +243 -238
  30. brainstate/nn/_dropout.py +618 -426
  31. brainstate/nn/_dropout_test.py +477 -100
  32. brainstate/nn/_dynamics.py +1267 -1343
  33. brainstate/nn/_dynamics_test.py +67 -78
  34. brainstate/nn/_elementwise.py +1298 -1119
  35. brainstate/nn/_elementwise_test.py +830 -169
  36. brainstate/nn/_embedding.py +408 -58
  37. brainstate/nn/_embedding_test.py +156 -0
  38. brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +233 -239
  39. brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +115 -114
  40. brainstate/nn/{_linear_mv.py → _event_linear.py} +83 -83
  41. brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +121 -120
  42. brainstate/nn/_exp_euler.py +254 -92
  43. brainstate/nn/_exp_euler_test.py +377 -35
  44. brainstate/nn/_linear.py +744 -424
  45. brainstate/nn/_linear_test.py +475 -107
  46. brainstate/nn/_metrics.py +1070 -0
  47. brainstate/nn/_metrics_test.py +611 -0
  48. brainstate/nn/_module.py +384 -377
  49. brainstate/nn/_module_test.py +40 -40
  50. brainstate/nn/_normalizations.py +1334 -975
  51. brainstate/nn/_normalizations_test.py +699 -73
  52. brainstate/nn/_paddings.py +1020 -0
  53. brainstate/nn/_paddings_test.py +723 -0
  54. brainstate/nn/_poolings.py +2239 -1177
  55. brainstate/nn/_poolings_test.py +953 -217
  56. brainstate/nn/{_rate_rnns.py → _rnns.py} +946 -554
  57. brainstate/nn/_rnns_test.py +593 -0
  58. brainstate/nn/_utils.py +216 -89
  59. brainstate/nn/_utils_test.py +402 -0
  60. brainstate/{init/_random_inits.py → nn/init.py} +809 -553
  61. brainstate/{init/_random_inits_test.py → nn/init_test.py} +180 -149
  62. brainstate/random/__init__.py +270 -24
  63. brainstate/random/_rand_funs.py +3938 -3616
  64. brainstate/random/_rand_funs_test.py +640 -567
  65. brainstate/random/_rand_seed.py +675 -210
  66. brainstate/random/_rand_seed_test.py +48 -48
  67. brainstate/random/_rand_state.py +1617 -1409
  68. brainstate/random/_rand_state_test.py +551 -0
  69. brainstate/transform/__init__.py +59 -0
  70. brainstate/transform/_ad_checkpoint.py +176 -0
  71. brainstate/{compile → transform}/_ad_checkpoint_test.py +49 -49
  72. brainstate/{augment → transform}/_autograd.py +1025 -778
  73. brainstate/{augment → transform}/_autograd_test.py +1289 -1289
  74. brainstate/transform/_conditions.py +316 -0
  75. brainstate/{compile → transform}/_conditions_test.py +220 -220
  76. brainstate/{compile → transform}/_error_if.py +94 -92
  77. brainstate/{compile → transform}/_error_if_test.py +52 -52
  78. brainstate/transform/_eval_shape.py +145 -0
  79. brainstate/{augment → transform}/_eval_shape_test.py +38 -38
  80. brainstate/{compile → transform}/_jit.py +399 -346
  81. brainstate/{compile → transform}/_jit_test.py +143 -143
  82. brainstate/{compile → transform}/_loop_collect_return.py +675 -536
  83. brainstate/{compile → transform}/_loop_collect_return_test.py +58 -58
  84. brainstate/{compile → transform}/_loop_no_collection.py +283 -184
  85. brainstate/{compile → transform}/_loop_no_collection_test.py +50 -50
  86. brainstate/transform/_make_jaxpr.py +2016 -0
  87. brainstate/transform/_make_jaxpr_test.py +1510 -0
  88. brainstate/transform/_mapping.py +529 -0
  89. brainstate/transform/_mapping_test.py +194 -0
  90. brainstate/{compile → transform}/_progress_bar.py +255 -202
  91. brainstate/{augment → transform}/_random.py +171 -151
  92. brainstate/{compile → transform}/_unvmap.py +256 -159
  93. brainstate/transform/_util.py +286 -0
  94. brainstate/typing.py +837 -304
  95. brainstate/typing_test.py +780 -0
  96. brainstate/util/__init__.py +27 -50
  97. brainstate/util/_others.py +1025 -0
  98. brainstate/util/_others_test.py +962 -0
  99. brainstate/util/_pretty_pytree.py +1301 -0
  100. brainstate/util/_pretty_pytree_test.py +675 -0
  101. brainstate/util/{pretty_repr.py → _pretty_repr.py} +462 -328
  102. brainstate/util/_pretty_repr_test.py +696 -0
  103. brainstate/util/filter.py +945 -469
  104. brainstate/util/filter_test.py +912 -0
  105. brainstate/util/struct.py +910 -523
  106. brainstate/util/struct_test.py +602 -0
  107. {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/METADATA +108 -91
  108. brainstate-0.2.1.dist-info/RECORD +111 -0
  109. {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/licenses/LICENSE +202 -202
  110. brainstate/augment/__init__.py +0 -30
  111. brainstate/augment/_eval_shape.py +0 -99
  112. brainstate/augment/_mapping.py +0 -1060
  113. brainstate/augment/_mapping_test.py +0 -597
  114. brainstate/compile/__init__.py +0 -38
  115. brainstate/compile/_ad_checkpoint.py +0 -204
  116. brainstate/compile/_conditions.py +0 -256
  117. brainstate/compile/_make_jaxpr.py +0 -888
  118. brainstate/compile/_make_jaxpr_test.py +0 -156
  119. brainstate/compile/_util.py +0 -147
  120. brainstate/functional/__init__.py +0 -27
  121. brainstate/graph/_graph_node.py +0 -244
  122. brainstate/graph/_graph_node_test.py +0 -73
  123. brainstate/graph/_graph_operation_test.py +0 -563
  124. brainstate/init/__init__.py +0 -26
  125. brainstate/init/_base.py +0 -52
  126. brainstate/init/_generic.py +0 -244
  127. brainstate/init/_regular_inits.py +0 -105
  128. brainstate/init/_regular_inits_test.py +0 -50
  129. brainstate/nn/_inputs.py +0 -608
  130. brainstate/nn/_ltp.py +0 -28
  131. brainstate/nn/_neuron.py +0 -705
  132. brainstate/nn/_neuron_test.py +0 -161
  133. brainstate/nn/_others.py +0 -46
  134. brainstate/nn/_projection.py +0 -486
  135. brainstate/nn/_rate_rnns_test.py +0 -63
  136. brainstate/nn/_readout.py +0 -209
  137. brainstate/nn/_readout_test.py +0 -53
  138. brainstate/nn/_stp.py +0 -236
  139. brainstate/nn/_synapse.py +0 -505
  140. brainstate/nn/_synapse_test.py +0 -131
  141. brainstate/nn/_synaptic_projection.py +0 -423
  142. brainstate/nn/_synouts.py +0 -162
  143. brainstate/nn/_synouts_test.py +0 -57
  144. brainstate/nn/metrics.py +0 -388
  145. brainstate/optim/__init__.py +0 -38
  146. brainstate/optim/_base.py +0 -64
  147. brainstate/optim/_lr_scheduler.py +0 -448
  148. brainstate/optim/_lr_scheduler_test.py +0 -50
  149. brainstate/optim/_optax_optimizer.py +0 -152
  150. brainstate/optim/_optax_optimizer_test.py +0 -53
  151. brainstate/optim/_sgd_optimizer.py +0 -1104
  152. brainstate/random/_random_for_unit.py +0 -52
  153. brainstate/surrogate.py +0 -1957
  154. brainstate/transform.py +0 -23
  155. brainstate/util/caller.py +0 -98
  156. brainstate/util/others.py +0 -540
  157. brainstate/util/pretty_pytree.py +0 -945
  158. brainstate/util/pretty_pytree_test.py +0 -159
  159. brainstate/util/pretty_table.py +0 -2954
  160. brainstate/util/scaling.py +0 -258
  161. brainstate-0.1.10.dist-info/RECORD +0 -130
  162. {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/WHEEL +0 -0
  163. {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/top_level.txt +0 -0
brainstate/nn/_inputs.py DELETED
@@ -1,608 +0,0 @@
1
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
- from typing import Union, Optional, Sequence, Callable
17
-
18
- import brainunit as u
19
- import jax
20
- import numpy as np
21
-
22
- from brainstate import environ, init, random
23
- from brainstate._state import ShortTermState, State, maybe_state
24
- from brainstate.compile import while_loop
25
- from brainstate.typing import ArrayLike, Size, DTypeLike
26
- from ._dynamics import Dynamics, Prefetch
27
- from ._module import Module
28
-
29
- __all__ = [
30
- 'SpikeTime',
31
- 'PoissonSpike',
32
- 'PoissonEncoder',
33
- 'PoissonInput',
34
- 'poisson_input',
35
- ]
36
-
37
-
38
class SpikeTime(Dynamics):
    """Input neuron group that emits spikes at predefined times.

    ``times[k]`` and ``indices[k]`` together describe one spike event: neuron
    ``indices[k]`` fires once the current time ``t`` (read from ``environ``)
    has reached ``times[k]``.

    >>> # Get 2 neurons, neuron 0 fires at 10 ms and neuron 1 fires at 20 ms.
    >>> SpikeTime(2, indices=[0, 1], times=[10, 20])
    >>> # or
    >>> # Get 2 neurons, the neuron 0 fires spikes at 10 ms and 20 ms.
    >>> SpikeTime(2, times=[10, 20], indices=[0, 0])
    >>> # or
    >>> # Get 2 neurons, neuron 0 fires at 10 ms and 30 ms, neuron 1 fires at 20 ms.
    >>> SpikeTime(2, times=[10, 20, 30], indices=[0, 1, 0])
    >>> # or
    >>> # Get 2 neurons; at 10 ms, neuron 0 fires; at 20 ms, neuron 0 and 1 fire;
    >>> # at 30 ms, neuron 1 fires.
    >>> SpikeTime(2, times=[10, 20, 20, 30], indices=[0, 0, 1, 1])

    Parameters
    ----------
    in_size : int, tuple, list
        The neuron group geometry.
    indices : list, tuple, ArrayType
        The neuron indices at each time point to emit spikes. Must have the
        same length as ``times``.
    times : list, tuple, ArrayType
        The time points which generate the spikes.
    spk_type : DTypeLike, default=bool
        Data type of the spike array returned by :meth:`update`.
    name : str, optional
        The name of the dynamic system.
    need_sort : bool, default=True
        If True, ``times`` is sorted in ascending order and ``indices`` is
        reordered accordingly. :meth:`update` walks the events front to back,
        so pass False only when ``times`` is already sorted.

    Raises
    ------
    ValueError
        If ``indices`` and ``times`` differ in length.
    """

    def __init__(
        self,
        in_size: Size,
        indices: Union[Sequence, ArrayLike],
        times: Union[Sequence, ArrayLike],
        spk_type: DTypeLike = bool,
        name: Optional[str] = None,
        need_sort: bool = True,
    ):
        super().__init__(in_size=in_size, name=name)

        # parameters
        if len(indices) != len(times):
            raise ValueError(f'The length of "indices" and "times" must be the same. '
                             f'However, we got {len(indices)} != {len(times)}.')
        self.num_times = len(times)
        self.spk_type = spk_type

        # data about times and indices
        self.times = u.math.asarray(times)
        self.indices = u.math.asarray(indices, dtype=environ.ditype())
        if need_sort:
            sort_idx = u.math.argsort(self.times)
            self.indices = self.indices[sort_idx]
            self.times = self.times[sort_idx]

    def init_state(self, *args, **kwargs):
        """Create the cursor ``self.i`` pointing at the next pending event."""
        # Start at the first event. A start value of -1 would make
        # ``self.times[i]`` wrap around to the *last* spike time and hold back
        # every spike until then.
        self.i = ShortTermState(0)

    def reset_state(self, batch_size=None, **kwargs):
        """Rewind the event cursor to the first event."""
        self.i.value = 0

    def update(self):
        """Emit every pending spike whose time is ``<= t``.

        Returns
        -------
        An array of shape ``self.varshape`` and dtype ``self.spk_type`` in
        which the entries (along the last axis) of the neurons firing at this
        step are set.
        """
        t = environ.get('t')

        def _cond_fun(spikes):
            i = self.i.value
            # NOTE(review): when i == num_times the lookup ``self.times[i]`` is
            # out of range and relies on the array library's out-of-bounds
            # index handling; the first operand already makes the result False.
            return u.math.logical_and(i < self.num_times, t >= self.times[i])

        def _body_fun(spikes):
            i = self.i.value
            spikes = spikes.at[..., self.indices[i]].set(True)
            self.i.value += 1
            return spikes

        spike = u.math.zeros(self.varshape, dtype=self.spk_type)
        spike = while_loop(_cond_fun, _body_fun, spike)
        return spike
114
-
115
-
116
class PoissonSpike(Dynamics):
    """
    Poisson neuron group whose firing rates are fixed at construction.

    On every :meth:`update` each element of the group draws one random number
    and spikes when that number is less than or equal to ``freqs * dt``.

    Parameters
    ----------
    in_size : Size
        The neuron group geometry.
    freqs : ArrayLike or Callable
        Firing rates; resolved against ``self.varshape`` through
        ``init.param`` (``None`` is rejected).
    spk_type : DTypeLike, default=bool
        Data type of the returned spike array.
    name : str, optional
        Name of the module.
    """

    def __init__(
        self,
        in_size: Size,
        freqs: Union[ArrayLike, Callable],
        spk_type: DTypeLike = bool,
        name: Optional[str] = None,
    ):
        super().__init__(in_size=in_size, name=name)

        self.spk_type = spk_type

        # firing rates, one per element of the group
        self.freqs = init.param(freqs, self.varshape, allow_none=False)

    def update(self):
        """Draw one spike array of shape ``self.varshape`` for the current step."""
        draws = random.rand(*self.varshape)
        fired = draws <= (self.freqs * environ.get_dt())
        return u.math.asarray(fired, dtype=self.spk_type)
139
-
140
-
141
class PoissonEncoder(Dynamics):
    r"""Rate-to-spike encoder driven by rates supplied at update time.

    Unlike :class:`PoissonSpike`, no rate is stored on the instance: the
    firing rates are passed to :meth:`update` on every call, so they may
    change from step to step.

    For every element a random number is drawn and a spike is emitted when
    the number is less than or equal to the per-step spiking probability

    $$
    P(\text{spike}) = \text{rate} \cdot \text{dt}
    $$

    Parameters
    ----------
    in_size : Size
        Size of the input to the encoder, defining the shape of the output
        spike train.
    spk_type : DTypeLike, default=bool
        Data type for the generated spikes.
    name : str, optional
        Name of the encoder module.

    Examples
    --------
    >>> import brainstate as bs
    >>> import brainunit as u
    >>> import numpy as np
    >>>
    >>> # Create a Poisson encoder for 10 neurons
    >>> encoder = bs.nn.PoissonEncoder(10)
    >>>
    >>> # Generate spikes with varying firing rates
    >>> rates = np.array([10, 20, 30, 40, 50, 60, 70, 80, 90, 100]) * u.Hz
    >>> spikes = encoder.update(rates)

    Notes
    -----
    - Draws in different calls are independent of each other; the encoder
      keeps no record of earlier spikes.
    - ``update`` reads the step size through ``environ.get_dt()``.
    """

    def __init__(
        self,
        in_size: Size,
        spk_type: DTypeLike = bool,
        name: Optional[str] = None,
    ):
        super().__init__(in_size=in_size, name=name)
        self.spk_type = spk_type

    def update(self, freqs: ArrayLike):
        """Return a spike array of shape ``self.varshape`` for rates ``freqs``."""
        draws = random.rand(*self.varshape)
        fired = draws <= (freqs * environ.get_dt())
        return u.math.asarray(fired, dtype=self.spk_type)
219
-
220
-
221
class PoissonInput(Module):
    r"""Module wrapper that applies :func:`poisson_input` to a target state.

    The constructor only stores its arguments. Each call to :meth:`update`
    looks up the target state and forwards the stored frequency, number of
    inputs, weight and indices to :func:`poisson_input`, which draws the
    number of input spikes per target element and adds ``count * weight`` to
    the target.

    With $p = \text{freq} \cdot \text{dt}$, the per-step spike count of
    ``num_input`` independent sources is modelled as

    $$
    \text{inputs} \sim \text{Binomial}(\text{num\_input}, p)
    $$

    and, for large numbers of inputs, by the normal approximation

    $$
    \text{inputs} \sim \mathcal{N}(\mu, \sigma^2)
    $$

    where $\mu = \text{num\_input} \cdot p$ and
    $\sigma^2 = \text{num\_input} \cdot p \cdot (1-p)$.

    Parameters
    ----------
    target : Prefetch
        Handle of the state that receives the input.
    indices : Union[np.ndarray, jax.Array]
        Indices of the target to receive input. If None, input is applied to
        the entire target.
    num_input : int
        The number of independent Poisson input sources.
    freq : Union[int, float]
        The firing frequency of each input source in Hz.
    weight : ndarray, float, or brainunit.Quantity
        The synaptic weight of each input spike.
    name : Optional[str], optional
        The name of this module.

    Notes
    -----
    - All inputs target the same variable with the same frequency and
      synaptic weight.
    - NOTE(review): ``update`` reads ``target.module`` and ``target.item``,
      so ``target`` must expose both attributes; a bare ``State`` object is
      presumably not accepted here -- confirm against ``Prefetch``.
    """

    def __init__(
        self,
        target: Prefetch,
        indices: Union[np.ndarray, jax.Array],
        num_input: int,
        freq: u.Quantity[u.Hz],
        weight: Union[jax.typing.ArrayLike, u.Quantity],
        name: Optional[str] = None,
    ):
        super().__init__(name=name)

        self.target = target
        self.indices = indices
        self.num_input = num_input
        self.freq = freq
        self.weight = weight

    def update(self):
        """Resolve the target state and add one step of Poisson input to it."""
        # the prefetch handle names an attribute (`item`) on a module
        state = getattr(self.target.module, self.target.item)
        poisson_input(self.freq, self.num_input, self.weight, state, self.indices)
345
-
346
-
347
def poisson_input(
    freq: u.Quantity[u.Hz],
    num_input: int,
    weight: Union[jax.typing.ArrayLike, u.Quantity],
    target: State,
    indices: Optional[Union[np.ndarray, jax.Array]] = None,
    refractory: Optional[Union[jax.Array]] = None,
):
    r"""Add one time step of Poisson-distributed input to a target state.

    For every (selected) element of ``target`` the number of spikes arriving
    from ``num_input`` independent sources during one step is drawn, scaled by
    ``weight`` and added to the element. The result is written back to
    ``target.value``; nothing is returned.

    With the per-step spike probability

    $$
    p = \text{freq} \cdot \text{dt}
    $$

    and

    $$
    a = \text{num\_input} \cdot p, \qquad b = \text{num\_input} \cdot (1 - p)
    $$

    the spike count is taken from the normal approximation when ``a > 5`` and
    ``b > 5``:

    $$
    \text{inputs} \sim \mathcal{N}(a,\ \sigma^2), \qquad \sigma^2 = b \cdot p
    $$

    and from direct binomial sampling otherwise:

    $$
    \text{inputs} \sim \text{Binomial}(\text{num\_input}, p)
    $$

    Both candidates are sampled and the applicable one is selected
    element-wise, so the choice also works when ``a`` and ``b`` are traced
    values.

    Parameters
    ----------
    freq : u.Quantity[u.Hz]
        The frequency of the Poisson input in Hertz. A ``State`` is unwrapped
        through ``maybe_state``.
    num_input : int
        The number of input channels or neurons generating the Poisson spikes.
    weight : u.Quantity
        The synaptic weight applied to each spike. A ``State`` is unwrapped
        through ``maybe_state``.
    target : State
        The target state variable to which the Poisson input is applied.
    indices : Optional[Union[np.ndarray, jax.Array]], optional
        Specific indices of the target to apply the input. If None, the input
        is applied to the entire target.
    refractory : Optional[Union[jax.Array]], optional
        A boolean array indicating which parts of the target are in a
        refractory state and keep their previous value.

    Raises
    ------
    AssertionError
        If ``target`` is not a ``State``.

    Examples
    --------
    >>> import brainstate as bs
    >>> import brainunit as u
    >>> import numpy as np
    >>>
    >>> # Create a membrane potential state
    >>> V = bs.HiddenState(np.zeros(100) * u.mV)
    >>>
    >>> # Add Poisson input to all neurons at 50 Hz
    >>> bs.nn.poisson_input(
    ...     freq=50 * u.Hz,
    ...     num_input=200,
    ...     weight=0.1 * u.mV,
    ...     target=V
    ... )
    >>>
    >>> # Apply Poisson input only to a subset of neurons
    >>> indices = np.array([0, 10, 20, 30])
    >>> bs.nn.poisson_input(
    ...     freq=100 * u.Hz,
    ...     num_input=50,
    ...     weight=0.2 * u.mV,
    ...     target=V,
    ...     indices=indices
    ... )
    >>>
    >>> # Apply input with refractory mask
    >>> refractory = np.zeros(100, dtype=bool)
    >>> refractory[40:60] = True  # neurons 40-59 are in refractory period
    >>> bs.nn.poisson_input(
    ...     freq=75 * u.Hz,
    ...     num_input=100,
    ...     weight=0.15 * u.mV,
    ...     target=V,
    ...     refractory=refractory
    ... )

    Notes
    -----
    - The weight parameter is applied uniformly to all generated spikes.
    - When refractory is provided, the corresponding target elements are not
      updated.
    """
    freq = maybe_state(freq)
    weight = maybe_state(weight)

    assert isinstance(target, State), 'The target must be a State.'
    p = freq * environ.get_dt()
    p = p.to_decimal() if isinstance(p, u.Quantity) else p
    a = num_input * p  # expected number of spikes per step
    b = num_input * (1 - p)  # expected number of silent sources per step
    tar_val = target.value
    use_normal = u.math.logical_and(a > 5, b > 5)
    # `random.normal` takes a standard deviation; b * p = n * p * (1 - p) is
    # the binomial *variance*, so its square root is passed. If p lies outside
    # [0, 1] this is NaN, but then `use_normal` is False and the normal draw
    # is discarded.
    std = u.math.sqrt(b * p)

    def _draw_counts(tar):
        # one spike count per target element that receives input
        shape = tar.shape if indices is None else tar[indices].shape
        approx = random.normal(a, std, shape, dtype=tar.dtype)
        exact = random.binomial(num_input, p, shape, check_valid=False, dtype=tar.dtype)
        return u.math.where(use_normal, approx, exact)

    def _add_input(tar):
        # add the weighted spike counts to the whole leaf or to `indices` only
        increment = _draw_counts(tar) * weight
        if indices is None:
            return tar + increment
        return tar.at[indices].add(increment)

    data = jax.tree.map(_add_input, tar_val, is_leaf=u.math.is_quantity)

    if refractory is not None:
        # refractory elements keep the value they had before this call
        data = jax.tree.map(
            lambda new, old: u.math.where(refractory, old, new),
            data,
            tar_val,
            is_leaf=u.math.is_quantity
        )
    target.value = data
brainstate/nn/_ltp.py DELETED
@@ -1,28 +0,0 @@
1
- # Copyright 2025 BDP Ecosystem Limited. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
- # -*- coding: utf-8 -*-
17
-
18
-
19
- from ._synapse import Synapse
20
-
21
- __all__ = [
22
- 'LongTermPlasticity',
23
- ]
24
-
25
-
26
class LongTermPlasticity(Synapse):
    """Marker subclass of ``Synapse`` for long-term plasticity models.

    Adds no attributes or behaviour of its own.
    """
28
-