brainstate 0.1.9__py2.py3-none-any.whl → 0.2.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (163)
  1. brainstate/__init__.py +130 -19
  2. brainstate/_compatible_import.py +201 -9
  3. brainstate/_compatible_import_test.py +681 -0
  4. brainstate/_deprecation.py +210 -0
  5. brainstate/_deprecation_test.py +2319 -0
  6. brainstate/{util/error.py → _error.py} +10 -20
  7. brainstate/_state.py +94 -47
  8. brainstate/_state_test.py +1 -1
  9. brainstate/_utils.py +1 -1
  10. brainstate/environ.py +1279 -347
  11. brainstate/environ_test.py +1187 -26
  12. brainstate/graph/__init__.py +6 -13
  13. brainstate/graph/_node.py +240 -0
  14. brainstate/graph/_node_test.py +589 -0
  15. brainstate/graph/{_graph_operation.py → _operation.py} +632 -746
  16. brainstate/graph/_operation_test.py +1147 -0
  17. brainstate/mixin.py +1209 -141
  18. brainstate/mixin_test.py +991 -51
  19. brainstate/nn/__init__.py +74 -72
  20. brainstate/nn/_activations.py +587 -295
  21. brainstate/nn/_activations_test.py +109 -86
  22. brainstate/nn/_collective_ops.py +393 -274
  23. brainstate/nn/_collective_ops_test.py +746 -15
  24. brainstate/nn/_common.py +114 -66
  25. brainstate/nn/_common_test.py +154 -0
  26. brainstate/nn/_conv.py +1652 -143
  27. brainstate/nn/_conv_test.py +838 -227
  28. brainstate/nn/_delay.py +95 -29
  29. brainstate/nn/_delay_test.py +25 -20
  30. brainstate/nn/_dropout.py +359 -167
  31. brainstate/nn/_dropout_test.py +429 -52
  32. brainstate/nn/_dynamics.py +14 -90
  33. brainstate/nn/_dynamics_test.py +1 -12
  34. brainstate/nn/_elementwise.py +492 -313
  35. brainstate/nn/_elementwise_test.py +806 -145
  36. brainstate/nn/_embedding.py +369 -19
  37. brainstate/nn/_embedding_test.py +156 -0
  38. brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +10 -16
  39. brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +6 -5
  40. brainstate/nn/{_linear_mv.py → _event_linear.py} +2 -2
  41. brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +6 -5
  42. brainstate/nn/_exp_euler.py +200 -38
  43. brainstate/nn/_exp_euler_test.py +350 -8
  44. brainstate/nn/_linear.py +391 -71
  45. brainstate/nn/_linear_test.py +427 -59
  46. brainstate/nn/_metrics.py +1070 -0
  47. brainstate/nn/_metrics_test.py +611 -0
  48. brainstate/nn/_module.py +10 -3
  49. brainstate/nn/_module_test.py +1 -1
  50. brainstate/nn/_normalizations.py +688 -329
  51. brainstate/nn/_normalizations_test.py +663 -37
  52. brainstate/nn/_paddings.py +1020 -0
  53. brainstate/nn/_paddings_test.py +723 -0
  54. brainstate/nn/_poolings.py +1404 -342
  55. brainstate/nn/_poolings_test.py +828 -92
  56. brainstate/nn/{_rate_rnns.py → _rnns.py} +446 -54
  57. brainstate/nn/_rnns_test.py +593 -0
  58. brainstate/nn/_utils.py +132 -5
  59. brainstate/nn/_utils_test.py +402 -0
  60. brainstate/{init/_random_inits.py → nn/init.py} +301 -45
  61. brainstate/{init/_random_inits_test.py → nn/init_test.py} +51 -20
  62. brainstate/random/__init__.py +247 -1
  63. brainstate/random/_rand_funs.py +668 -346
  64. brainstate/random/_rand_funs_test.py +74 -1
  65. brainstate/random/_rand_seed.py +541 -76
  66. brainstate/random/_rand_seed_test.py +1 -1
  67. brainstate/random/_rand_state.py +601 -393
  68. brainstate/random/_rand_state_test.py +551 -0
  69. brainstate/transform/__init__.py +59 -0
  70. brainstate/transform/_ad_checkpoint.py +176 -0
  71. brainstate/{compile → transform}/_ad_checkpoint_test.py +1 -1
  72. brainstate/{augment → transform}/_autograd.py +360 -113
  73. brainstate/{augment → transform}/_autograd_test.py +2 -2
  74. brainstate/transform/_conditions.py +316 -0
  75. brainstate/{compile → transform}/_conditions_test.py +11 -11
  76. brainstate/{compile → transform}/_error_if.py +22 -20
  77. brainstate/{compile → transform}/_error_if_test.py +1 -1
  78. brainstate/transform/_eval_shape.py +145 -0
  79. brainstate/{augment → transform}/_eval_shape_test.py +1 -1
  80. brainstate/{compile → transform}/_jit.py +99 -46
  81. brainstate/{compile → transform}/_jit_test.py +3 -3
  82. brainstate/{compile → transform}/_loop_collect_return.py +219 -80
  83. brainstate/{compile → transform}/_loop_collect_return_test.py +1 -1
  84. brainstate/{compile → transform}/_loop_no_collection.py +133 -34
  85. brainstate/{compile → transform}/_loop_no_collection_test.py +2 -2
  86. brainstate/transform/_make_jaxpr.py +2016 -0
  87. brainstate/transform/_make_jaxpr_test.py +1510 -0
  88. brainstate/transform/_mapping.py +529 -0
  89. brainstate/transform/_mapping_test.py +194 -0
  90. brainstate/{compile → transform}/_progress_bar.py +78 -25
  91. brainstate/{augment → transform}/_random.py +65 -45
  92. brainstate/{compile → transform}/_unvmap.py +102 -5
  93. brainstate/transform/_util.py +286 -0
  94. brainstate/typing.py +594 -61
  95. brainstate/typing_test.py +780 -0
  96. brainstate/util/__init__.py +9 -32
  97. brainstate/util/_others.py +1025 -0
  98. brainstate/util/_others_test.py +962 -0
  99. brainstate/util/_pretty_pytree.py +1301 -0
  100. brainstate/util/_pretty_pytree_test.py +675 -0
  101. brainstate/util/{pretty_repr.py → _pretty_repr.py} +161 -27
  102. brainstate/util/_pretty_repr_test.py +696 -0
  103. brainstate/util/filter.py +557 -81
  104. brainstate/util/filter_test.py +912 -0
  105. brainstate/util/struct.py +769 -382
  106. brainstate/util/struct_test.py +602 -0
  107. {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/METADATA +34 -17
  108. brainstate-0.2.0.dist-info/RECORD +111 -0
  109. brainstate/augment/__init__.py +0 -30
  110. brainstate/augment/_eval_shape.py +0 -99
  111. brainstate/augment/_mapping.py +0 -1060
  112. brainstate/augment/_mapping_test.py +0 -597
  113. brainstate/compile/__init__.py +0 -38
  114. brainstate/compile/_ad_checkpoint.py +0 -204
  115. brainstate/compile/_conditions.py +0 -256
  116. brainstate/compile/_make_jaxpr.py +0 -888
  117. brainstate/compile/_make_jaxpr_test.py +0 -156
  118. brainstate/compile/_util.py +0 -147
  119. brainstate/functional/__init__.py +0 -27
  120. brainstate/graph/_graph_node.py +0 -244
  121. brainstate/graph/_graph_node_test.py +0 -73
  122. brainstate/graph/_graph_operation_test.py +0 -563
  123. brainstate/init/__init__.py +0 -26
  124. brainstate/init/_base.py +0 -52
  125. brainstate/init/_generic.py +0 -244
  126. brainstate/init/_regular_inits.py +0 -105
  127. brainstate/init/_regular_inits_test.py +0 -50
  128. brainstate/nn/_inputs.py +0 -608
  129. brainstate/nn/_ltp.py +0 -28
  130. brainstate/nn/_neuron.py +0 -705
  131. brainstate/nn/_neuron_test.py +0 -161
  132. brainstate/nn/_others.py +0 -46
  133. brainstate/nn/_projection.py +0 -486
  134. brainstate/nn/_rate_rnns_test.py +0 -63
  135. brainstate/nn/_readout.py +0 -209
  136. brainstate/nn/_readout_test.py +0 -53
  137. brainstate/nn/_stp.py +0 -236
  138. brainstate/nn/_synapse.py +0 -505
  139. brainstate/nn/_synapse_test.py +0 -131
  140. brainstate/nn/_synaptic_projection.py +0 -423
  141. brainstate/nn/_synouts.py +0 -162
  142. brainstate/nn/_synouts_test.py +0 -57
  143. brainstate/nn/metrics.py +0 -388
  144. brainstate/optim/__init__.py +0 -38
  145. brainstate/optim/_base.py +0 -64
  146. brainstate/optim/_lr_scheduler.py +0 -448
  147. brainstate/optim/_lr_scheduler_test.py +0 -50
  148. brainstate/optim/_optax_optimizer.py +0 -152
  149. brainstate/optim/_optax_optimizer_test.py +0 -53
  150. brainstate/optim/_sgd_optimizer.py +0 -1104
  151. brainstate/random/_random_for_unit.py +0 -52
  152. brainstate/surrogate.py +0 -1957
  153. brainstate/transform.py +0 -23
  154. brainstate/util/caller.py +0 -98
  155. brainstate/util/others.py +0 -540
  156. brainstate/util/pretty_pytree.py +0 -945
  157. brainstate/util/pretty_pytree_test.py +0 -159
  158. brainstate/util/pretty_table.py +0 -2954
  159. brainstate/util/scaling.py +0 -258
  160. brainstate-0.1.9.dist-info/RECORD +0 -130
  161. {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/WHEEL +0 -0
  162. {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/licenses/LICENSE +0 -0
  163. {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/top_level.txt +0 -0
brainstate/nn/_inputs.py DELETED
@@ -1,608 +0,0 @@
1
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
- from typing import Union, Optional, Sequence, Callable
17
-
18
- import brainunit as u
19
- import jax
20
- import numpy as np
21
-
22
- from brainstate import environ, init, random
23
- from brainstate._state import ShortTermState, State, maybe_state
24
- from brainstate.compile import while_loop
25
- from brainstate.typing import ArrayLike, Size, DTypeLike
26
- from ._dynamics import Dynamics, Prefetch
27
- from ._module import Module
28
-
29
# Public names exported by this module (kept as a list so callers can
# concatenate it with other ``__all__`` lists).
__all__ = ['SpikeTime', 'PoissonSpike', 'PoissonEncoder', 'PoissonInput', 'poisson_input']
36
-
37
-
38
class SpikeTime(Dynamics):
    """The input neuron group characterized by spikes emitting at given times.

    Event ``k`` pairs ``times[k]`` with ``indices[k]``. On each ``update``,
    every not-yet-emitted event whose time is ``<= environ.get('t')`` is
    consumed, and the corresponding neuron is set to spike.

    >>> # Get 2 neurons; neuron 0 fires at 10 ms, neuron 1 fires at 20 ms.
    >>> SpikeTime(2, indices=[0, 1], times=[10, 20])
    >>> # or
    >>> # Get 2 neurons, the neuron 0 fires spikes at 10 ms and 20 ms.
    >>> SpikeTime(2, times=[10, 20], indices=[0, 0])
    >>> # or
    >>> # Get 2 neurons, neuron 0 fires at 10 ms and 30 ms, neuron 1 fires at 20 ms.
    >>> SpikeTime(2, times=[10, 20, 30], indices=[0, 1, 0])
    >>> # or
    >>> # Get 2 neurons; at 10 ms, neuron 0 fires; at 20 ms, neuron 0 and 1 fire;
    >>> # at 30 ms, neuron 1 fires.
    >>> SpikeTime(2, times=[10, 20, 20, 30], indices=[0, 0, 1, 1])

    Parameters
    ----------
    in_size : int, tuple, list
        The neuron group geometry.
    indices : list, tuple, ArrayType
        The neuron index of each spike event.
    times : list, tuple, ArrayType
        The time of each spike event. Must have the same length as ``indices``.
    spk_type : DTypeLike, default=bool
        Data type of the spike array returned by ``update``.
    name : str, optional
        The name of the dynamic system.
    need_sort : bool, default=True
        Sort the events by time at construction. ``update`` consumes events
        front-to-back, so pass ``False`` only if ``times`` is already in
        ascending order.

    Raises
    ------
    ValueError
        If ``indices`` and ``times`` have different lengths.
    """

    def __init__(
        self,
        in_size: Size,
        indices: Union[Sequence, ArrayLike],
        times: Union[Sequence, ArrayLike],
        spk_type: DTypeLike = bool,
        name: Optional[str] = None,
        need_sort: bool = True,
    ):
        super().__init__(in_size=in_size, name=name)

        # parameters
        if len(indices) != len(times):
            raise ValueError(f'The length of "indices" and "times" must be the same. '
                             f'However, we got {len(indices)} != {len(times)}.')
        self.num_times = len(times)
        self.spk_type = spk_type

        # data about times and indices
        self.times = u.math.asarray(times)
        self.indices = u.math.asarray(indices, dtype=environ.ditype())
        if need_sort:
            # ``update`` walks the events in order, so they must be ascending in time.
            sort_idx = u.math.argsort(self.times)
            self.indices = self.indices[sort_idx]
            self.times = self.times[sort_idx]

    def init_state(self, *args, **kwargs):
        # Position of the next event to emit. It must start at 0 (the earliest
        # event): starting at -1 would make the first check read ``times[-1]``,
        # i.e. the latest event, and stall the group until that time.
        self.i = ShortTermState(0)

    def reset_state(self, batch_size=None, **kwargs):
        # Rewind to the first (earliest) event.
        self.i.value = 0

    def update(self):
        t = environ.get('t')
        spike = u.math.zeros(self.varshape, dtype=self.spk_type)
        if self.num_times == 0:
            # No events at all: skip the loop, which would index empty arrays.
            return spike

        def _cond_fun(spikes):
            # Continue while events remain and the next one is due.
            i = self.i.value
            return u.math.logical_and(i < self.num_times, t >= self.times[i])

        def _body_fun(spikes):
            # Emit the current event and advance the cursor.
            i = self.i.value
            spikes = spikes.at[..., self.indices[i]].set(True)
            self.i.value += 1
            return spikes

        return while_loop(_cond_fun, _body_fun, spike)
114
-
115
-
116
class PoissonSpike(Dynamics):
    """Poisson neuron group with fixed firing rates.

    On every ``update``, each neuron independently spikes with probability
    ``freqs * dt``, where ``dt`` is ``environ.get_dt()``.

    Parameters
    ----------
    in_size : Size
        Shape of the neuron group.
    freqs : ArrayLike or Callable
        Firing rates. They are materialized with ``init.param`` to shape
        ``varshape``; ``None`` is not allowed.
    spk_type : DTypeLike, default=bool
        Data type of the spike array returned by ``update``.
    name : str, optional
        Name of the module.
    """

    def __init__(
        self,
        in_size: Size,
        freqs: Union[ArrayLike, Callable],
        spk_type: DTypeLike = bool,
        name: Optional[str] = None,
    ):
        super().__init__(in_size=in_size, name=name)

        self.spk_type = spk_type

        # parameters
        self.freqs = init.param(freqs, self.varshape, allow_none=False)

    def update(self):
        # Uniform draws lie in [0, 1). The strict comparison makes
        # P(spike) == freqs * dt exactly, and guarantees that a zero rate
        # never spikes. With ``<=``, a draw of exactly 0 would spike.
        spikes = random.rand(*self.varshape) < (self.freqs * environ.get_dt())
        return u.math.asarray(spikes, dtype=self.spk_type)
139
-
140
-
141
class PoissonEncoder(Dynamics):
    r"""Poisson spike encoder for converting firing rates to spike trains.

    This class implements a Poisson process to generate spikes based on provided
    firing rates. Unlike the PoissonSpike class, this encoder accepts firing rates
    as input during the update step rather than having them fixed at initialization.

    The spike generation follows a Poisson process where the probability of a spike
    in each time step is proportional to the firing rate and the simulation time step:

    $$
    P(\text{spike}) = \text{rate} \cdot \text{dt}
    $$

    For each neuron and time step, the encoder draws a random number from the
    uniform distribution on [0, 1) and generates a spike if the number is strictly
    less than the spiking probability. A zero rate therefore never spikes.

    Parameters
    ----------
    in_size : Size
        Size of the input to the encoder, defining the shape of the output spike train.
    spk_type : DTypeLike, default=bool
        Data type for the generated spikes. Typically boolean for binary spikes.
    name : str, optional
        Name of the encoder module.

    Examples
    --------
    >>> import brainstate as bs
    >>> import brainunit as u
    >>> import numpy as np
    >>>
    >>> # Create a Poisson encoder for 10 neurons
    >>> encoder = bs.nn.PoissonEncoder(10)
    >>>
    >>> # Generate spikes with varying firing rates
    >>> rates = np.array([10, 20, 30, 40, 50, 60, 70, 80, 90, 100]) * u.Hz
    >>> spikes = encoder.update(rates)
    >>>
    >>> # Use in a more complex processing pipeline
    >>> # First, generate rate-coded output from an analog signal
    >>> analog_values = np.random.rand(10) * 100  # values between 0 and 100
    >>> firing_rates = analog_values * u.Hz  # convert to firing rates
    >>> spike_train = encoder.update(firing_rates)
    >>>
    >>> # Feed the spikes into a spiking neural network
    >>> neuron_layer = bs.nn.LIF(10)
    >>> neuron_layer.init_state(batch_size=1)
    >>> output_spikes = neuron_layer.update(spike_train)

    Notes
    -----
    - This encoder is particularly useful for rate-to-spike conversion in neuromorphic
      computing applications and sensory encoding tasks.
    - Spikes are independent Bernoulli draws per time step. Inter-spike intervals
      (in steps) are therefore geometrically distributed, which approaches an
      exponential distribution for small time steps (dt).
    - For small time steps (dt), the number of spikes in a time window T approximately
      follows a Poisson distribution with parameter λ = rate * T.
    - Unlike PoissonSpike which has fixed rates, this encoder allows dynamic rate changes
      with every update call, making it suitable for encoding time-varying signals.
    - The independence of spike generation between time steps results in renewal process
      statistics without memory of previous spiking history.
    """

    def __init__(
        self,
        in_size: Size,
        spk_type: DTypeLike = bool,
        name: Optional[str] = None,
    ):
        super().__init__(in_size=in_size, name=name)
        self.spk_type = spk_type

    def update(self, freqs: ArrayLike):
        """Sample one time step of spikes for the given firing rates.

        Parameters
        ----------
        freqs : ArrayLike
            Firing rates for this step. They are multiplied by ``environ.get_dt()``
            and compared against draws of shape ``varshape``.

        Returns
        -------
        ArrayLike
            Spike array of shape ``varshape`` cast to ``spk_type``.
        """
        # Strict ``<`` against [0, 1) draws: P(spike) == freqs * dt exactly,
        # and a zero rate can never spike.
        spikes = random.rand(*self.varshape) < (freqs * environ.get_dt())
        return u.math.asarray(spikes, dtype=self.spk_type)
219
-
220
-
221
class PoissonInput(Module):
    r"""Poisson Input to the given state variable.

    This class provides a way to add independent Poisson-distributed spiking input
    to a target state variable. For large numbers of inputs, this implementation is
    computationally more efficient than creating separate Poisson spike generators.

    The synaptic events are generated randomly during simulation runtime and are not
    preloaded or stored in memory, which improves memory efficiency for large-scale
    simulations. All inputs target the same variable with the same frequency and
    synaptic weight.

    The Poisson process generates spikes with probability based on the frequency and
    simulation time step:

    $$
    P(\text{spike}) = \text{freq} \cdot \text{dt}
    $$

    For computational efficiency, two different methods are used for spike generation:

    1. For large numbers of inputs, a normal approximation:
       $$
       \text{inputs} \sim \mathcal{N}(\mu, \sigma^2)
       $$
       where $\mu = \text{num\_input} \cdot p$ and $\sigma^2 = \text{num\_input} \cdot p \cdot (1-p)$

    2. For smaller numbers, a direct binomial sampling:
       $$
       \text{inputs} \sim \text{Binomial}(\text{num\_input}, p)
       $$

    where $p = \text{freq} \cdot \text{dt}$ in both cases.

    Parameters
    ----------
    target : Prefetch or State
        The variable that is targeted by this input. It is either a ``Prefetch``,
        resolved on every update to ``getattr(target.module, target.item)``, or
        a :py:class:`brainstate.State` passed in directly.
    indices : Union[np.ndarray, jax.Array], optional
        Indices of the target to receive input. If None, input is applied to the entire target.
    num_input : int
        The number of independent Poisson input sources.
    freq : Union[int, float]
        The firing frequency of each input source in Hz.
    weight : ndarray, float, or brainunit.Quantity
        The synaptic weight of each input spike.
    name : Optional[str], optional
        The name of this module.

    Examples
    --------
    >>> import brainstate as bs
    >>> import brainunit as u
    >>> import numpy as np
    >>>
    >>> # Create a neuron group with membrane potential
    >>> neuron = bs.nn.LIF(100)
    >>> neuron.init_state(batch_size=1)
    >>>
    >>> # Add Poisson input to all neurons
    >>> poisson_in = bs.nn.PoissonInput(
    ...     target=neuron.V,
    ...     indices=None,
    ...     num_input=200,
    ...     freq=50 * u.Hz,
    ...     weight=0.1 * u.mV
    ... )
    >>>
    >>> # Add Poisson input only to specific neurons
    >>> indices = np.array([0, 10, 20, 30])
    >>> specific_input = bs.nn.PoissonInput(
    ...     target=neuron.V,
    ...     indices=indices,
    ...     num_input=50,
    ...     freq=100 * u.Hz,
    ...     weight=0.2 * u.mV
    ... )
    >>>
    >>> # Run simulation with the inputs
    >>> for t in range(100):
    ...     poisson_in.update()
    ...     specific_input.update()
    ...     neuron.update()

    Notes
    -----
    - The Poisson inputs are statistically independent between update steps and across
      target neurons.
    - This implementation is particularly efficient for large numbers of inputs or targets.
    - For very sparse connectivity patterns, consider using individual PoissonSpike neurons
      with specific connectivity patterns instead.
    - The update method internally calls the poisson_input function which handles the
      spike generation and target state updates.
    """

    def __init__(
        self,
        target: Union[Prefetch, State],
        indices: Optional[Union[np.ndarray, jax.Array]],
        num_input: int,
        freq: u.Quantity[u.Hz],
        weight: Union[jax.typing.ArrayLike, u.Quantity],
        name: Optional[str] = None,
    ):
        super().__init__(name=name)

        self.target = target
        self.indices = indices
        self.num_input = num_input
        self.freq = freq
        self.weight = weight

    def update(self):
        """Draw one step of Poisson input and add it to the target state in place."""
        if isinstance(self.target, State):
            # A State given directly (as in the examples, e.g. ``neuron.V``).
            target_state = self.target
        else:
            # A Prefetch: resolve the named attribute on its module lazily, so the
            # state may be created after this module is constructed.
            target_state = getattr(self.target.module, self.target.item)

        # generate Poisson input
        poisson_input(
            self.freq,
            self.num_input,
            self.weight,
            target_state,
            self.indices,
        )
345
-
346
-
347
def poisson_input(
    freq: u.Quantity[u.Hz],
    num_input: int,
    weight: Union[jax.typing.ArrayLike, u.Quantity],
    target: State,
    indices: Optional[Union[np.ndarray, jax.Array]] = None,
    refractory: Optional[Union[jax.Array]] = None,
):
    r"""Generates Poisson-distributed input spikes to a target state variable.

    This function simulates Poisson input to a given state, updating the target
    variable with generated spikes based on the specified frequency, number of inputs,
    and synaptic weight. The input can be applied to specific indices of the target
    or to the entire target if indices are not provided.

    The function uses two different methods to generate the Poisson-distributed input:
    1. For large numbers of inputs (a > 5 and b > 5), a normal approximation is used
    2. For smaller numbers, a direct binomial sampling approach is used

    Mathematical model for Poisson input:
    $$
    P(\text{spike}) = \text{freq} \cdot \text{dt}
    $$

    For the normal approximation (when a > 5 and b > 5), with mean $a$ and
    variance $b \cdot p$ (standard deviation $\sqrt{b \cdot p}$):
    $$
    \text{inputs} \sim \mathcal{N}(a, \; b \cdot p)
    $$
    where:
    $$
    a = \text{num\_input} \cdot p
    $$
    $$
    b = \text{num\_input} \cdot (1 - p)
    $$
    $$
    p = \text{freq} \cdot \text{dt}
    $$

    For direct binomial sampling (when a ≤ 5 or b ≤ 5):
    $$
    \text{inputs} \sim \text{Binomial}(\text{num\_input}, p)
    $$

    Parameters
    ----------
    freq : u.Quantity[u.Hz]
        The frequency of the Poisson input in Hertz. May also be a State.
    num_input : int
        The number of input channels or neurons generating the Poisson spikes.
    weight : u.Quantity
        The synaptic weight applied to each spike. May also be a State.
    target : State
        The target state variable to which the Poisson input is applied. Its
        value is updated in place.
    indices : Optional[Union[np.ndarray, jax.Array]], optional
        Specific indices of the target to apply the input. If None, the input is applied
        to the entire target.
    refractory : Optional[Union[jax.Array]], optional
        A boolean array indicating which parts of the target are in a refractory state
        and should not be updated. It must be broadcastable to the target's shape.

    Raises
    ------
    AssertionError
        If ``target`` is not a ``State``.

    Examples
    --------
    >>> import brainstate as bs
    >>> import brainunit as u
    >>> import numpy as np
    >>>
    >>> # Create a membrane potential state
    >>> V = bs.HiddenState(np.zeros(100) * u.mV)
    >>>
    >>> # Add Poisson input to all neurons at 50 Hz
    >>> bs.nn.poisson_input(
    ...     freq=50 * u.Hz,
    ...     num_input=200,
    ...     weight=0.1 * u.mV,
    ...     target=V
    ... )
    >>>
    >>> # Apply Poisson input only to a subset of neurons
    >>> indices = np.array([0, 10, 20, 30])
    >>> bs.nn.poisson_input(
    ...     freq=100 * u.Hz,
    ...     num_input=50,
    ...     weight=0.2 * u.mV,
    ...     target=V,
    ...     indices=indices
    ... )
    >>>
    >>> # Apply input with refractory mask
    >>> refractory = np.zeros(100, dtype=bool)
    >>> refractory[40:60] = True  # neurons 40-59 are in refractory period
    >>> bs.nn.poisson_input(
    ...     freq=75 * u.Hz,
    ...     num_input=100,
    ...     weight=0.15 * u.mV,
    ...     target=V,
    ...     refractory=refractory
    ... )

    Notes
    -----
    - Both samplers are evaluated on every call. The result is then selected
      elementwise with ``where`` according to the a > 5 and b > 5 condition.
    - Samples from the normal approximation are continuous and are not clipped,
      so individual draws need not be non-negative integers.
    - The weight parameter is applied uniformly to all generated spikes.
    - When refractory is provided, the corresponding target elements are not updated.
    """
    freq = maybe_state(freq)
    weight = maybe_state(weight)

    assert isinstance(target, State), 'The target must be a State.'
    p = freq * environ.get_dt()
    p = p.to_decimal() if isinstance(p, u.Quantity) else p
    a = num_input * p  # mean of Binomial(num_input, p)
    b = num_input * (1 - p)
    # ``random.normal`` takes a *standard deviation* as its scale; the binomial
    # variance is n * p * (1 - p) == b * p, so the scale is its square root.
    std = u.math.sqrt(b * p)
    tar_val = target.value
    cond = u.math.logical_and(a > 5, b > 5)

    def _sample_shape(tar):
        # Shape of the region receiving input: the whole leaf, or only the
        # selected entries when ``indices`` is given.
        return tar.shape if indices is None else tar[indices].shape

    # generate Poisson input (normal draws for every leaf first, then binomial
    # draws, matching the original order of random-number consumption)
    branch1 = jax.tree.map(
        lambda tar: random.normal(a, std, _sample_shape(tar), dtype=tar.dtype),
        tar_val,
        is_leaf=u.math.is_quantity,
    )
    branch2 = jax.tree.map(
        lambda tar: random.binomial(
            num_input,
            p,
            _sample_shape(tar),
            check_valid=False,
            dtype=tar.dtype,
        ),
        tar_val,
        is_leaf=u.math.is_quantity,
    )
    inp = jax.tree.map(
        lambda b1, b2: u.math.where(cond, b1, b2),
        branch1,
        branch2,
        is_leaf=u.math.is_quantity,
    )

    # update target variable
    if indices is None:
        data = jax.tree.map(
            lambda tar, x: tar + x * weight,
            tar_val,
            inp,
            is_leaf=u.math.is_quantity,
        )
    else:
        data = jax.tree.map(
            lambda tar, x: tar.at[indices].add(x * weight),
            tar_val,
            inp,
            is_leaf=u.math.is_quantity,
        )

    if refractory is not None:
        # Keep the previous value wherever the refractory mask is set.
        target.value = jax.tree.map(
            lambda x, tar: u.math.where(refractory, tar, x),
            data,
            tar_val,
            is_leaf=u.math.is_quantity,
        )
    else:
        target.value = data
brainstate/nn/_ltp.py DELETED
@@ -1,28 +0,0 @@
1
- # Copyright 2025 BDP Ecosystem Limited. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
- # -*- coding: utf-8 -*-
17
-
18
-
19
- from ._synapse import Synapse
20
-
21
# Public names exported by this module.
__all__ = ['LongTermPlasticity']
24
-
25
-
26
class LongTermPlasticity(Synapse):
    """Long-term plasticity synapse type.

    NOTE(review): currently an empty subclass of ``Synapse``; it adds no
    attributes or methods of its own.
    """
28
-